mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #7943 from rsepassi:main
PiperOrigin-RevId: 397871285
This commit is contained in:
commit
2d03bde7e5
@ -301,7 +301,7 @@ def _result_dtype(op, *args):
|
||||
|
||||
|
||||
def _arraylike(x):
|
||||
return isinstance(x, ndarray) or isscalar(x) or hasattr(x, '__jax_array__')
|
||||
return isinstance(x, ndarray) or hasattr(x, '__jax_array__') or isscalar(x)
|
||||
|
||||
def _check_arraylike(fun_name, *args):
|
||||
"""Check if all args fit JAX's definition of arraylike."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user