mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Tests: fix some failures for upstream numpy
This commit is contained in:
parent
41c7cce7c6
commit
95a209f28b
@ -2007,7 +2007,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testTrapz(self, yshape, xshape, dtype, dx, axis):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
|
||||
np_fun = partial(np.trapz, dx=dx, axis=axis)
|
||||
# TODO(jakevdp): numpy.trapz is removed in numpy 2.0
|
||||
np_fun = jtu.ignore_warning(category=DeprecationWarning)(
|
||||
partial(np.trapz, dx=dx, axis=axis))
|
||||
jnp_fun = partial(jnp.trapz, dx=dx, axis=axis)
|
||||
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
|
||||
dtypes.bfloat16: 4e-2})
|
||||
|
@ -2871,7 +2871,7 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
if op_name == "bitwise_not":
|
||||
raise unittest.SkipTest("https://github.com/google/jax/issues/12066")
|
||||
# Find a valid dtype for the function.
|
||||
for dtype in [np.float_, np.int_, np.complex_, np.bool_]:
|
||||
for dtype in [float, int, complex, bool]:
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
if dtype in rec_dtypes:
|
||||
py_val = dtype.type(1).item()
|
||||
|
Loading…
x
Reference in New Issue
Block a user