parallelization rule for lax.select

This commit is contained in:
Roy Frostig 2019-04-03 15:13:04 -07:00
parent 2cec9f97d5
commit 794af8bd55
2 changed files with 37 additions and 0 deletions

View File

@ -2637,6 +2637,19 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
pred = broadcast_in_dim(pred, on_true.shape, [0])
return select(pred, on_true, on_false), 0
def _select_papply_rule(name, vals, dims):
dimset = set([d for d in dims if d is not None])
if len(dimset) != 1:
raise NotImplementedError(
'papply of select with operands split along different dimensions')
like_val, like_dim = [(v, d) for v, d in zip(vals, dims) if d is not None][0]
def normalize_split(val, dim):
return psplit_like(val, like_val, name) if dim is None else val
vals = [normalize_split(v, d) for v, d in zip(vals, dims)]
return select_p.bind(*vals), like_dim
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
ad.defjvp(select_p,
None,
@ -2644,6 +2657,8 @@ ad.defjvp(select_p,
lambda g, b, x, y: select(b, _zeros(g), g))
ad.primitive_transposes[select_p] = _select_transpose_rule
batching.primitive_batchers[select_p] = _select_batch_rule
parallel.papply_primitive_rules[select_p] = _select_papply_rule
def _slice_shape_rule(operand, start_indices, limit_indices, strides,
operand_shape):

View File

@ -123,6 +123,28 @@ class PapplyTest(jtu.JaxTestCase):
expected = onp.max(arg, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
def testSelect(self):
pfun, axis_name = papply(lax.select, 5,
in_axes=(None, 0, None))
p = onp.arange(15).reshape((5, 3)) % 4 == 1
t = onp.ones((5, 3))
f = onp.zeros((5, 3))
jaxpr = make_jaxpr(pfun)(p, t[0], f)
def expected_spmd(p, t, f):
return lax.select(
lax.psplit_like(p, t, axis_name),
t,
lax.psplit_like(f, t, axis_name))
expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
assert repr(jaxpr) == repr(expected_jaxpr)
ans = serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
expected = lax.select(p, t, f)
self.assertAllClose(ans, expected, check_dtypes=True)
@skip
def DISABLED_testLogSoftmax(self):