relax tanh test tols for upcoming xla change

This commit is contained in:
Matthew Johnson 2022-12-20 21:06:09 -08:00
parent 39b14b1b1f
commit 6ba0ef6505
2 changed files with 3 additions and 2 deletions

View File

@ -281,7 +281,8 @@ class JetTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_cosh(self): self.unary_check(jnp.cosh)
@jtu.skip_on_devices("tpu")
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5)
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5,
atol=5e-3)
@jtu.skip_on_devices("tpu")
def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5)
@jtu.skip_on_devices("tpu")

View File

@ -170,7 +170,7 @@ LAX_GRAD_SPECIAL_VALUE_TESTS = [
grad_special_values_test_spec(
lax.cosh, [0.],
tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
grad_special_values_test_spec(lax.tanh, [0., 1000.]),
grad_special_values_test_spec(lax.tanh, [0., 1000.], tol=5e-3),
grad_special_values_test_spec(lax.sin, [0., np.pi, np.pi/2., np.pi/4.]),
grad_special_values_test_spec(lax.cos, [0., np.pi, np.pi/2., np.pi/4.]),
grad_special_values_test_spec(lax.tan, [0.]),