mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Tests: reenable some ufunc input tests
This commit is contained in:
parent
e81578a9fa
commit
97c32f67fc
@ -6502,8 +6502,7 @@ class NumpyUfuncTests(jtu.JaxTestCase):
|
||||
for arg_dtypes in jtu.cases_from_list(_dtypes_for_ufunc(name)))
|
||||
def testUfuncInputTypes(self, name, arg_dtypes):
|
||||
# TODO(jakevdp): fix following failures and remove from this exception list.
|
||||
if (name in ['divmod', 'floor_divide', 'fmod', 'gcd', 'left_shift', 'mod',
|
||||
'power', 'remainder', 'right_shift', 'rint', 'square']
|
||||
if (name in ['gcd', 'left_shift', 'power', 'remainder', 'right_shift', 'rint', 'square']
|
||||
and 'bool_' in arg_dtypes):
|
||||
self.skipTest(f"jax.numpy does not support {name}{tuple(arg_dtypes)}")
|
||||
if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating):
|
||||
|
Loading…
x
Reference in New Issue
Block a user