Remove more uses of tan() in reduction tests.

This is to avoid subtly brittle tests. Tan() is an ill-conditioned function to evaluate near it's singularities.
This commit is contained in:
Rasmus Munk Larsen 2023-01-19 15:20:02 -08:00 committed by GitHub
parent b0a7075f66
commit c798fcaefc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1548,7 +1548,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
c = jnp.sin(c * b)
assert b.shape == ()
return c, b
@ -1581,7 +1581,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
c = jnp.sin(c * b)
assert b.shape == ()
return c, b
@ -1821,7 +1821,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
c = jnp.sin(c * b)
assert b.shape == ()
return c, b
@ -1852,7 +1852,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def f(c, a):
a1, a2 = a
c1, c2 = c
b = jnp.sum(jnp.cos(a1)) * jnp.sum(jnp.tan(c2 * a2))
b = jnp.sum(jnp.cos(a1)) * jnp.sum(c2 * a2)
c = c1 * jnp.sin(jnp.sum(a1 * a2)), c2 * jnp.cos(jnp.sum(a1))
return c, b
@ -2221,7 +2221,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def f(c, a):
assert a.shape == (3,)
assert c.shape == (4,)
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
c = jnp.sin(c * b)
assert b.shape == ()
return c, b