mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #25237 from jakevdp:faster-isscalar
PiperOrigin-RevId: 702517550
This commit is contained in:
commit
40122f7c03
@ -624,9 +624,11 @@ def isscalar(element: Any) -> bool:
|
||||
>>> jnp.isscalar(slice(10))
|
||||
False
|
||||
"""
|
||||
if (isinstance(element, (np.ndarray, jax.Array))
|
||||
or hasattr(element, '__jax_array__')
|
||||
or np.isscalar(element)):
|
||||
if np.isscalar(element):
|
||||
return True
|
||||
elif isinstance(element, (np.ndarray, jax.Array)):
|
||||
return element.ndim == 0
|
||||
elif hasattr(element, '__jax_array__'):
|
||||
return asarray(element).ndim == 0
|
||||
return False
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user