diff --git a/CHANGELOG.md b/CHANGELOG.md index d5a01b780..86d0cab0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. A downstream effect of this several other internal functions need debug info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more detail. + * In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`, + non-arraylike inputs (such as lists, tuples, etc.) are now deprecated. * Bug fixes * TPU runtime startup and shutdown time should be significantly improved on diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 1db2e0bde..e281c63ae 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -26,7 +26,7 @@ from jax._src import dtypes from jax._src.lax import lax from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding -from jax._src.util import safe_zip, safe_map +from jax._src.util import safe_zip, safe_map, set_module from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape from jax.sharding import Sharding @@ -35,6 +35,8 @@ import numpy as np zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map +export = set_module('jax.numpy') + _dtype = partial(dtypes.dtype, canonicalize=True) def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: @@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin return SingleDeviceSharding(device) else: return device + + +@export +def ndim(a: ArrayLike) -> int: + """Return the number of dimensions of an array. + + JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function + raises a :class:`TypeError` if the input is a collection such as a list or + tuple. + + Args: + a: array-like object. + + Returns: + An integer specifying the number of dimensions of ``a``. + + Examples: + Number of dimensions for arrays: + + >>> x = jnp.arange(10) + >>> jnp.ndim(x) + 1 + >>> y = jnp.ones((2, 3)) + >>> jnp.ndim(y) + 2 + + This also works for scalars: + + >>> jnp.ndim(3.14) + 0 + + For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property: + + >>> x.ndim + 1 + """ + # Deprecation warning added 2025-2-20. + check_arraylike("ndim", a, emit_warning=True) + return np.ndim(a) # NumPy dispatches to a.ndim if available. + + +@export +def shape(a: ArrayLike) -> tuple[int, ...]: + """Return the shape an array. + + JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function + raises a :class:`TypeError` if the input is a collection such as a list or + tuple. + + Args: + a: array-like object. + + Returns: + An tuple of integers representing the shape of ``a``. + + Examples: + Shape for arrays: + + >>> x = jnp.arange(10) + >>> jnp.shape(x) + (10,) + >>> y = jnp.ones((2, 3)) + >>> jnp.shape(y) + (2, 3) + + This also works for scalars: + + >>> jnp.shape(3.14) + () + + For arrays, this can also be accessed via the :attr:`jax.Array.shape` property: + + >>> x.shape + (10,) + """ + # Deprecation warning added 2025-2-20. + check_arraylike("shape", a, emit_warning=True) + return np.shape(a) # NumPy dispatches to a.shape if available. + + +@export +def size(a: ArrayLike, axis: int | None = None) -> int: + """Return number of elements along a given axis. + + JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function + raises a :class:`TypeError` if the input is a collection such as a list or + tuple. + + Args: + a: array-like object + axis: optional integer along which to count elements. By default, return + the total number of elements. + + Returns: + An integer specifying the number of elements in ``a``. + + Examples: + Size for arrays: + + >>> x = jnp.arange(10) + >>> jnp.size(x) + 10 + >>> y = jnp.ones((2, 3)) + >>> jnp.size(y) + 6 + >>> jnp.size(y, axis=1) + 3 + + This also works for scalars: + + >>> jnp.size(3.14) + 1 + + For arrays, this can also be accessed via the :attr:`jax.Array.size` property: + + >>> y.size + 6 + """ + # Deprecation warning added 2025-2-20. + check_arraylike("size", a, emit_warning=True) + return np.size(a, axis=axis) # NumPy dispatches to a.size if available. diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index d563483a2..ad71b9f74 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -254,6 +254,12 @@ from jax._src.numpy.tensor_contractions import ( vdot as vdot, ) +from jax._src.numpy.util import ( + ndim as ndim, + shape as shape, + size as size, +) + from jax._src.numpy.window_functions import ( bartlett as bartlett, blackman as blackman, @@ -279,15 +285,12 @@ from numpy import ( integer as integer, iterable as iterable, nan as nan, - ndim as ndim, newaxis as newaxis, number as number, object_ as object_, pi as pi, save as save, savez as savez, - shape as shape, - size as size, signedinteger as signedinteger, unsignedinteger as unsignedinteger, ) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index dee61c145..b73a3b95b 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., ddof: int = 0, keepdims: builtins.bool = False, where: ArrayLike | None = ...) -> Array: ... ndarray = Array -ndim = _np.ndim +def ndim(a: ArrayLike) -> int: ... def negative(x: ArrayLike, /) -> Array: ... newaxis = None def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ... @@ -842,7 +842,7 @@ def setdiff1d( fill_value: ArrayLike | None = ..., ) -> Array: ... def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ... -shape = _np.shape +def shape(a: ArrayLike) -> tuple[int, ...]: ... def sign(x: ArrayLike, /) -> Array: ... def signbit(x: ArrayLike, /) -> Array: ... signedinteger = _np.signedinteger @@ -850,7 +850,7 @@ def sin(x: ArrayLike, /) -> Array: ... def sinc(x: ArrayLike, /) -> Array: ... single: Any def sinh(x: ArrayLike, /) -> Array: ... -size = _np.size +def size(a: ArrayLike, axis: int | None = None) -> int: ... def sort( a: ArrayLike, axis: int | None = ..., diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 50773e23b..98f10d9c0 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6140,6 +6140,42 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol, check_dtypes=False) + @jtu.sample_product( + shape=all_shapes, + dtype=default_dtypes, + op=['ndim', 'shape', 'size'], + ) + def testNdimShapeSize(self, shape, dtype, op): + rng = jtu.rand_default(self.rng()) + jnp_op = getattr(jnp, op) + np_op = getattr(np, op) + x = rng(shape, dtype) + expected = np_op(x) + self.assertEqual(expected, jnp_op(x)) # np.ndarray or scalar input. + self.assertEqual(expected, jnp_op(jnp.asarray(x))) # jax.Array input. + self.assertEqual(expected, jax.jit(jnp_op)(x)) # Traced input. + + @jtu.sample_product( + shape=nonzerodim_shapes, + dtype=default_dtypes, + ) + def testSizeAlongAxis(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + axis = self.rng().randint(-len(shape), len(shape)) + np_op = partial(np.size, axis=axis) + jnp_op = partial(jnp.size, axis=axis) + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + @jtu.sample_product( + op=[jnp.ndim, jnp.shape, jnp.size], + ) + def testNdimShapeSizeNonArrayInput(self, op): + msg = f"{op.__name__} requires ndarray or scalar arguments" + with self.assertWarnsRegex(DeprecationWarning, msg): + op([1, 2, 3]) + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest.