diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 4767e48c3..0edc09fa7 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -715,7 +715,7 @@ class JaxNumpyReducerTests(jtu.JaxTestCase): # TODO(phawkins): we currently set dtype=False because we aren't as # aggressive about promoting to float64. It's not clear we want to mimic # Numpy here. - tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} + tol_spec = {np.float16: 4e-2, np.float32: 2e-4, np.float64: 5e-6} tol = max(jtu.tolerance(a_dtype, tol_spec), jtu.tolerance(q_dtype, tol_spec)) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 860b3358d..f09553c83 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4768,7 +4768,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): else x.astype(np.float32) for x in choicelist] dtype = jnp.result_type(default, *choicelist) return np.select(condlist, - [np.asarray(x, dtype=dtype) for x in choicelist], + [np.asarray(x).astype(dtype) for x in choicelist], np.asarray(default, dtype=dtype)) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 66d84c427..50d2ee725 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -226,7 +226,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase): rng = jtu.rand_positive(self.rng()) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, - tol={np.float32: 1e-3, np.float64: 1e-14}) + tol={np.float32: 1e-3, np.float64: 1e-14}, + check_dtypes=False) self._CompileAndCheck( lax_fun, args_maker, rtol={ np.float32: 5e-5 if jtu.test_device_matches(["tpu"]) else 1e-05,