mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #17882 from jakevdp:fix-nightly
PiperOrigin-RevId: 570049158
This commit is contained in:
commit
0fe420e1ef
@ -5412,15 +5412,16 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
_available_numpy_dtypes: list[str] = [dtype.__name__ for dtype in jtu.dtypes.all
|
||||
if dtype != dtypes.bfloat16]
|
||||
|
||||
# TODO(jakevdp): implement missing ufuncs
|
||||
UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_count'}
|
||||
|
||||
|
||||
def _all_numpy_ufuncs() -> Iterator[str]:
|
||||
"""Generate the names of all ufuncs in the top-level numpy namespace."""
|
||||
for name in dir(np):
|
||||
f = getattr(np, name)
|
||||
if isinstance(f, np.ufunc):
|
||||
# jnp.spacing is not implemented.
|
||||
if f.__name__ != "spacing":
|
||||
yield name
|
||||
if isinstance(f, np.ufunc) and name not in UNIMPLEMENTED_UFUNCS:
|
||||
yield name
|
||||
|
||||
|
||||
def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user