Merge pull request #6738 from hawkinsp:ascan

PiperOrigin-RevId: 373417478
This commit is contained in:
jax authors 2021-05-12 11:47:03 -07:00
commit 8f71d20a8b
2 changed files with 8 additions and 2 deletions

View File

@ -2382,8 +2382,9 @@ def _interleave(a, b, axis):
b_pad = [(0, 0, 0)] * b.ndim
a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
lax.pad(b, lax._const(b, 0), b_pad))
op = lax.bitwise_or if a.dtype == np.bool_ else lax.add
return op(lax.pad(a, lax._const(a, 0), a_pad),
lax.pad(b, lax._const(b, 0), b_pad))
@api_boundary
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):

View File

@ -2466,6 +2466,11 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(result.second, np.array([0., 10., 30.]),
check_dtypes=False)
def testAssociativeScanOfBools(self):
x = jnp.array([False, True, True, True, False, True])
y = lax.associative_scan(lax.bitwise_xor, x)
self.assertArraysEqual(np.array([False, True, False, True, True, False]), y)
def test_scan_typecheck_param(self):
d = jnp.ones(2)
def f(c, a):