mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove more uses of tan() in reduction tests.
This is to avoid subtly brittle tests. Tan() is an ill-conditioned function to evaluate near it's singularities.
This commit is contained in:
parent
b0a7075f66
commit
c798fcaefc
@ -1548,7 +1548,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def f(c, a):
|
||||
assert a.shape == (3,)
|
||||
assert c.shape == (4,)
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
|
||||
c = jnp.sin(c * b)
|
||||
assert b.shape == ()
|
||||
return c, b
|
||||
@ -1581,7 +1581,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def f(c, a):
|
||||
assert a.shape == (3,)
|
||||
assert c.shape == (4,)
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
|
||||
c = jnp.sin(c * b)
|
||||
assert b.shape == ()
|
||||
return c, b
|
||||
@ -1821,7 +1821,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def f(c, a):
|
||||
assert a.shape == (3,)
|
||||
assert c.shape == (4,)
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
|
||||
c = jnp.sin(c * b)
|
||||
assert b.shape == ()
|
||||
return c, b
|
||||
@ -1852,7 +1852,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def f(c, a):
|
||||
a1, a2 = a
|
||||
c1, c2 = c
|
||||
b = jnp.sum(jnp.cos(a1)) * jnp.sum(jnp.tan(c2 * a2))
|
||||
b = jnp.sum(jnp.cos(a1)) * jnp.sum(c2 * a2)
|
||||
c = c1 * jnp.sin(jnp.sum(a1 * a2)), c2 * jnp.cos(jnp.sum(a1))
|
||||
return c, b
|
||||
|
||||
@ -2221,7 +2221,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def f(c, a):
|
||||
assert a.shape == (3,)
|
||||
assert c.shape == (4,)
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(jnp.tan(d)))
|
||||
b = jnp.cos(jnp.sum(jnp.sin(a)) + jnp.sum(jnp.cos(c)) + jnp.sum(d))
|
||||
c = jnp.sin(c * b)
|
||||
assert b.shape == ()
|
||||
return c, b
|
||||
|
Loading…
x
Reference in New Issue
Block a user