Improve trace-time performance of jnp.isscalar

This commit is contained in:
Jake VanderPlas 2024-12-03 15:43:33 -08:00
parent 9e5edb7015
commit 0140a98e34

View File

@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool:
>>> jnp.isscalar(slice(10))
False
"""
if (isinstance(element, (np.ndarray, jax.Array))
or hasattr(element, '__jax_array__')
or np.isscalar(element)):
if np.isscalar(element):
return True
elif isinstance(element, (np.ndarray, jax.Array)):
return element.ndim == 0
elif hasattr(element, '__jax_array__'):
return asarray(element).ndim == 0
return False