Merge pull request #6741 from hawkinsp:ascan2

PiperOrigin-RevId: 373422349
This commit is contained in:
jax authors 2021-05-12 12:10:34 -07:00
commit db79701732

View File

@ -2471,6 +2471,22 @@ class LaxControlFlowTest(jtu.JaxTestCase):
y = lax.associative_scan(lax.bitwise_xor, x)
self.assertArraysEqual(np.array([False, True, False, True, True, False]), y)
@parameterized.named_parameters({"testcase_name": f"_{shape}", "shape": shape}
for shape in [2, 43, 100])
def testAssociativeScanSolvingRegressionTest(self, shape):
# This test checks that the batching rule doesn't raise for a batch
# sensitive function (solve).
ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
vs = np.ones((shape, 2))
@api.vmap
def fn(a, b):
m1, v1 = a
m2, v2 = b
return m1 + m2, jsp.linalg.solve(m1, v2) + jsp.linalg.solve(m2, v1)
_ = lax.associative_scan(fn, elems=(ms, vs))
def test_scan_typecheck_param(self):
d = jnp.ones(2)
def f(c, a):