Merge pull request #11833 from jakevdp:ufunc-tests

PiperOrigin-RevId: 466838445
This commit is contained in:
jax authors 2022-08-10 19:02:58 -07:00
commit 4ecd4db27f

View File

@ -6502,8 +6502,7 @@ class NumpyUfuncTests(jtu.JaxTestCase):
for arg_dtypes in jtu.cases_from_list(_dtypes_for_ufunc(name)))
def testUfuncInputTypes(self, name, arg_dtypes):
# TODO(jakevdp): fix following failures and remove from this exception list.
if (name in ['divmod', 'floor_divide', 'fmod', 'gcd', 'left_shift', 'mod',
'power', 'remainder', 'right_shift', 'rint', 'square']
if (name in ['gcd', 'left_shift', 'power', 'remainder', 'right_shift', 'rint', 'square']
and 'bool_' in arg_dtypes):
self.skipTest(f"jax.numpy does not support {name}{tuple(arg_dtypes)}")
if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating):