From 0140a98e34786790332e60d1d3b4d8a82d29d896 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Dec 2024 15:43:33 -0800 Subject: [PATCH] Improve trace-time performance of jnp.isscalar --- jax/_src/numpy/lax_numpy.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6a0e4059c..5af8c6dda 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool: >>> jnp.isscalar(slice(10)) False """ - if (isinstance(element, (np.ndarray, jax.Array)) - or hasattr(element, '__jax_array__') - or np.isscalar(element)): + if np.isscalar(element): + return True + elif isinstance(element, (np.ndarray, jax.Array)): + return element.ndim == 0 + elif hasattr(element, '__jax_array__'): return asarray(element).ndim == 0 return False