Tests: fix some failures for upstream numpy

This commit is contained in:
Jake VanderPlas 2023-09-20 12:26:12 -07:00
parent 41c7cce7c6
commit 95a209f28b
2 changed files with 4 additions and 2 deletions

View File

@ -2007,7 +2007,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testTrapz(self, yshape, xshape, dtype, dx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
np_fun = partial(np.trapz, dx=dx, axis=axis)
# TODO(jakevdp): numpy.trapz is removed in numpy 2.0
np_fun = jtu.ignore_warning(category=DeprecationWarning)(
partial(np.trapz, dx=dx, axis=axis))
jnp_fun = partial(jnp.trapz, dx=dx, axis=axis)
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
dtypes.bfloat16: 4e-2})

View File

@ -2871,7 +2871,7 @@ class LazyConstantTest(jtu.JaxTestCase):
if op_name == "bitwise_not":
raise unittest.SkipTest("https://github.com/google/jax/issues/12066")
# Find a valid dtype for the function.
for dtype in [np.float_, np.int_, np.complex_, np.bool_]:
for dtype in [float, int, complex, bool]:
dtype = dtypes.canonicalize_dtype(dtype)
if dtype in rec_dtypes:
py_val = dtype.type(1).item()