Merge pull request #17882 from jakevdp:fix-nightly

PiperOrigin-RevId: 570049158
This commit is contained in:
jax authors 2023-10-02 06:37:02 -07:00
commit 0fe420e1ef

View File

@ -5412,15 +5412,16 @@ class NumpySignaturesTest(jtu.JaxTestCase):
_available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all
if dtype != dtypes.bfloat16]
# TODO(jakevdp): implement missing ufuncs
UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_count'}
def _all_numpy_ufuncs() -> Iterator[str]:
"""Generate the names of all ufuncs in the top-level numpy namespace."""
for name in dir(np):
f = getattr(np, name)
if isinstance(f, np.ufunc):
# jnp.spacing is not implemented.
if f.__name__ != "spacing":
yield name
if isinstance(f, np.ufunc) and name not in UNIMPLEMENTED_UFUNCS:
yield name
def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]: