mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6741 from hawkinsp:ascan2
PiperOrigin-RevId: 373422349
This commit is contained in:
commit
db79701732
@ -2471,6 +2471,22 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
y = lax.associative_scan(lax.bitwise_xor, x)
|
||||
self.assertArraysEqual(np.array([False, True, False, True, True, False]), y)
|
||||
|
||||
@parameterized.named_parameters({"testcase_name": f"_{shape}", "shape": shape}
|
||||
for shape in [2, 43, 100])
|
||||
def testAssociativeScanSolvingRegressionTest(self, shape):
|
||||
# This test checks that the batching rule doesn't raise for a batch
|
||||
# sensitive function (solve).
|
||||
ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)
|
||||
vs = np.ones((shape, 2))
|
||||
|
||||
@api.vmap
|
||||
def fn(a, b):
|
||||
m1, v1 = a
|
||||
m2, v2 = b
|
||||
return m1 + m2, jsp.linalg.solve(m1, v2) + jsp.linalg.solve(m2, v1)
|
||||
|
||||
_ = lax.associative_scan(fn, elems=(ms, vs))
|
||||
|
||||
def test_scan_typecheck_param(self):
|
||||
d = jnp.ones(2)
|
||||
def f(c, a):
|
||||
|
Loading…
x
Reference in New Issue
Block a user