From 8cec6e636ad8de654830b52c123d9f6c97cc69b1 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas <jakevdp@google.com>
Date: Thu, 20 Feb 2025 12:31:31 -0800
Subject: [PATCH] jax.numpy ndim/shape/size: deprecate non-array input

---
 CHANGELOG.md            |   2 +
 jax/_src/numpy/util.py  | 125 +++++++++++++++++++++++++++++++++++++++-
 jax/numpy/__init__.py   |   9 ++-
 jax/numpy/__init__.pyi  |   6 +-
 tests/lax_numpy_test.py |  36 ++++++++++++
 5 files changed, 171 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d5a01b780..86d0cab0f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -54,6 +54,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     A downstream effect of this several other internal functions need debug
     info. This change does not affect public APIs.
     See https://github.com/jax-ml/jax/issues/26480 for more detail.
+  * In {func}`jax.numpy.ndim`, {func}`jax.numpy.shape`, and {func}`jax.numpy.size`,
+    non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
 
 * Bug fixes
   * TPU runtime startup and shutdown time should be significantly improved on
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index 1db2e0bde..e281c63ae 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -26,7 +26,7 @@ from jax._src import dtypes
 from jax._src.lax import lax
 from jax._src.lib import xla_client as xc
 from jax._src.sharding_impls import SingleDeviceSharding
-from jax._src.util import safe_zip, safe_map
+from jax._src.util import safe_zip, safe_map, set_module
 from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
 from jax.sharding import Sharding
 
@@ -35,6 +35,8 @@ import numpy as np
 zip, unsafe_zip = safe_zip, zip
 map, unsafe_map = safe_map, map
 
+export = set_module('jax.numpy')
+
 _dtype = partial(dtypes.dtype, canonicalize=True)
 
 def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
@@ -308,3 +310,124 @@ def normalize_device_to_sharding(device: xc.Device | Sharding | None) -> Shardin
     return SingleDeviceSharding(device)
   else:
     return device
+
+
+@export
+def ndim(a: ArrayLike) -> int:
+  """Return the number of dimensions of an array.
+
+  JAX implementation of :func:`numpy.ndim`. Unlike ``np.ndim``, this function
+  raises a :class:`TypeError` if the input is a collection such as a list or
+  tuple.
+
+  Args:
+    a: array-like object.
+
+  Returns:
+    An integer specifying the number of dimensions of ``a``.
+
+  Examples:
+    Number of dimensions for arrays:
+
+    >>> x = jnp.arange(10)
+    >>> jnp.ndim(x)
+    1
+    >>> y = jnp.ones((2, 3))
+    >>> jnp.ndim(y)
+    2
+
+    This also works for scalars:
+
+    >>> jnp.ndim(3.14)
+    0
+
+    For arrays, this can also be accessed via the :attr:`jax.Array.ndim` property:
+
+    >>> x.ndim
+    1
+  """
+  # Deprecation warning added 2025-2-20.
+  check_arraylike("ndim", a, emit_warning=True)
+  return np.ndim(a)  # NumPy dispatches to a.ndim if available.
+
+
+@export
+def shape(a: ArrayLike) -> tuple[int, ...]:
+  """Return the shape an array.
+
+  JAX implementation of :func:`numpy.shape`. Unlike ``np.shape``, this function
+  raises a :class:`TypeError` if the input is a collection such as a list or
+  tuple.
+
+  Args:
+    a: array-like object.
+
+  Returns:
+    An tuple of integers representing the shape of ``a``.
+
+  Examples:
+    Shape for arrays:
+
+    >>> x = jnp.arange(10)
+    >>> jnp.shape(x)
+    (10,)
+    >>> y = jnp.ones((2, 3))
+    >>> jnp.shape(y)
+    (2, 3)
+
+    This also works for scalars:
+
+    >>> jnp.shape(3.14)
+    ()
+
+    For arrays, this can also be accessed via the :attr:`jax.Array.shape` property:
+
+    >>> x.shape
+    (10,)
+  """
+  # Deprecation warning added 2025-2-20.
+  check_arraylike("shape", a, emit_warning=True)
+  return np.shape(a)  # NumPy dispatches to a.shape if available.
+
+
+@export
+def size(a: ArrayLike, axis: int | None = None) -> int:
+  """Return number of elements along a given axis.
+
+  JAX implementation of :func:`numpy.size`. Unlike ``np.size``, this function
+  raises a :class:`TypeError` if the input is a collection such as a list or
+  tuple.
+
+  Args:
+    a: array-like object
+    axis: optional integer along which to count elements. By default, return
+      the total number of elements.
+
+  Returns:
+    An integer specifying the number of elements in ``a``.
+
+  Examples:
+    Size for arrays:
+
+    >>> x = jnp.arange(10)
+    >>> jnp.size(x)
+    10
+    >>> y = jnp.ones((2, 3))
+    >>> jnp.size(y)
+    6
+    >>> jnp.size(y, axis=1)
+    3
+
+    This also works for scalars:
+
+    >>> jnp.size(3.14)
+    1
+
+    For arrays, this can also be accessed via the :attr:`jax.Array.size` property:
+
+    >>> y.size
+    6
+  """
+  # Deprecation warning added 2025-2-20.
+  check_arraylike("size", a, emit_warning=True)
+  return np.size(a, axis=axis)  # NumPy dispatches to a.size if available.
diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py
index d563483a2..ad71b9f74 100644
--- a/jax/numpy/__init__.py
+++ b/jax/numpy/__init__.py
@@ -254,6 +254,12 @@ from jax._src.numpy.tensor_contractions import (
   vdot as vdot,
 )
 
+from jax._src.numpy.util import (
+  ndim as ndim,
+  shape as shape,
+  size as size,
+)
+
 from jax._src.numpy.window_functions import (
     bartlett as bartlett,
     blackman as blackman,
@@ -279,15 +285,12 @@ from numpy import (
     integer as integer,
     iterable as iterable,
     nan as nan,
-    ndim as ndim,
     newaxis as newaxis,
     number as number,
     object_ as object_,
     pi as pi,
     save as save,
     savez as savez,
-    shape as shape,
-    size as size,
     signedinteger as signedinteger,
     unsignedinteger as unsignedinteger,
 )
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index dee61c145..b73a3b95b 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -728,7 +728,7 @@ def nanvar(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
            ddof: int = 0, keepdims: builtins.bool = False,
            where: ArrayLike | None = ...) -> Array: ...
 ndarray = Array
-ndim = _np.ndim
+def ndim(a: ArrayLike) -> int: ...
 def negative(x: ArrayLike, /) -> Array: ...
 newaxis = None
 def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: ...
@@ -842,7 +842,7 @@ def setdiff1d(
     fill_value: ArrayLike | None = ...,
 ) -> Array: ...
 def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ...) -> Array: ...
-shape = _np.shape
+def shape(a: ArrayLike) -> tuple[int, ...]: ...
 def sign(x: ArrayLike, /) -> Array: ...
 def signbit(x: ArrayLike, /) -> Array: ...
 signedinteger = _np.signedinteger
@@ -850,7 +850,7 @@ def sin(x: ArrayLike, /) -> Array: ...
 def sinc(x: ArrayLike, /) -> Array: ...
 single: Any
 def sinh(x: ArrayLike, /) -> Array: ...
-size = _np.size
+def size(a: ArrayLike, axis: int | None = None) -> int: ...
 def sort(
     a: ArrayLike,
     axis: int | None = ...,
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 50773e23b..98f10d9c0 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -6140,6 +6140,42 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
                           check_dtypes=False)
 
+  @jtu.sample_product(
+      shape=all_shapes,
+      dtype=default_dtypes,
+      op=['ndim', 'shape', 'size'],
+  )
+  def testNdimShapeSize(self, shape, dtype, op):
+    rng = jtu.rand_default(self.rng())
+    jnp_op = getattr(jnp, op)
+    np_op = getattr(np, op)
+    x = rng(shape, dtype)
+    expected = np_op(x)
+    self.assertEqual(expected, jnp_op(x))  # np.ndarray or scalar input.
+    self.assertEqual(expected, jnp_op(jnp.asarray(x)))  # jax.Array input.
+    self.assertEqual(expected, jax.jit(jnp_op)(x))  # Traced input.
+
+  @jtu.sample_product(
+      shape=nonzerodim_shapes,
+      dtype=default_dtypes,
+  )
+  def testSizeAlongAxis(self, shape, dtype):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype)]
+    axis = self.rng().randint(-len(shape), len(shape))
+    np_op = partial(np.size, axis=axis)
+    jnp_op = partial(jnp.size, axis=axis)
+    self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
+    self._CompileAndCheck(jnp_op, args_maker)
+
+  @jtu.sample_product(
+      op=[jnp.ndim, jnp.shape, jnp.size],
+  )
+  def testNdimShapeSizeNonArrayInput(self, op):
+    msg = f"{op.__name__} requires ndarray or scalar arguments"
+    with self.assertWarnsRegex(DeprecationWarning, msg):
+      op([1, 2, 3])
+
 
 # Most grad tests are at the lax level (see lax_test.py), but we add some here
 # as needed for e.g. particular compound ops of interest.