mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6738 from hawkinsp:ascan
PiperOrigin-RevId: 373417478
This commit is contained in:
commit
8f71d20a8b
@ -2382,8 +2382,9 @@ def _interleave(a, b, axis):
|
||||
b_pad = [(0, 0, 0)] * b.ndim
|
||||
a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1)
|
||||
b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1)
|
||||
return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
|
||||
lax.pad(b, lax._const(b, 0), b_pad))
|
||||
op = lax.bitwise_or if a.dtype == np.bool_ else lax.add
|
||||
return op(lax.pad(a, lax._const(a, 0), a_pad),
|
||||
lax.pad(b, lax._const(b, 0), b_pad))
|
||||
|
||||
@api_boundary
|
||||
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
|
@ -2466,6 +2466,11 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(result.second, np.array([0., 10., 30.]),
|
||||
check_dtypes=False)
|
||||
|
||||
def testAssociativeScanOfBools(self):
|
||||
x = jnp.array([False, True, True, True, False, True])
|
||||
y = lax.associative_scan(lax.bitwise_xor, x)
|
||||
self.assertArraysEqual(np.array([False, True, False, True, True, False]), y)
|
||||
|
||||
def test_scan_typecheck_param(self):
|
||||
d = jnp.ones(2)
|
||||
def f(c, a):
|
||||
|
Loading…
x
Reference in New Issue
Block a user