Merge pull request #7943 from rsepassi:main

PiperOrigin-RevId: 397871285
This commit is contained in:
jax authors 2021-09-20 16:42:00 -07:00
commit 2d03bde7e5

View File

@ -301,7 +301,7 @@ def _result_dtype(op, *args):
def _arraylike(x):
return isinstance(x, ndarray) or isscalar(x) or hasattr(x, '__jax_array__')
return isinstance(x, ndarray) or hasattr(x, '__jax_array__') or isscalar(x)
def _check_arraylike(fun_name, *args):
"""Check if all args fit JAX's definition of arraylike."""