Merge pull request #25237 from jakevdp:faster-isscalar

PiperOrigin-RevId: 702517550
This commit is contained in:
jax authors 2024-12-03 17:03:10 -08:00
commit 40122f7c03

View File

@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool:
>>> jnp.isscalar(slice(10))
False
"""
if (isinstance(element, (np.ndarray, jax.Array))
or hasattr(element, '__jax_array__')
or np.isscalar(element)):
if np.isscalar(element):
return True
elif isinstance(element, (np.ndarray, jax.Array)):
return element.ndim == 0
elif hasattr(element, '__jax_array__'):
return asarray(element).ndim == 0
return False