mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
relax tanh test tols for upcoming xla change
This commit is contained in:
parent
39b14b1b1f
commit
6ba0ef6505
@ -281,7 +281,8 @@ class JetTest(jtu.JaxTestCase):
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_cosh(self): self.unary_check(jnp.cosh)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5)
|
||||
def test_tanh(self): self.unary_check(jnp.tanh, lims=[-500, 500], order=5,
|
||||
atol=5e-3)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_logistic(self): self.unary_check(lax.logistic, lims=[-100, 100], order=5)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
|
@ -170,7 +170,7 @@ LAX_GRAD_SPECIAL_VALUE_TESTS = [
|
||||
grad_special_values_test_spec(
|
||||
lax.cosh, [0.],
|
||||
tol={np.float32: 1e-2} if jtu.device_under_test() == "tpu" else None),
|
||||
grad_special_values_test_spec(lax.tanh, [0., 1000.]),
|
||||
grad_special_values_test_spec(lax.tanh, [0., 1000.], tol=5e-3),
|
||||
grad_special_values_test_spec(lax.sin, [0., np.pi, np.pi/2., np.pi/4.]),
|
||||
grad_special_values_test_spec(lax.cos, [0., np.pi, np.pi/2., np.pi/4.]),
|
||||
grad_special_values_test_spec(lax.tan, [0.]),
|
||||
|
Loading…
x
Reference in New Issue
Block a user