mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parallelization rule for lax.select
This commit is contained in:
parent
2cec9f97d5
commit
794af8bd55
15
jax/lax.py
15
jax/lax.py
@ -2637,6 +2637,19 @@ def _select_batch_rule(batched_args, batch_dims, **unused_kwargs):
|
||||
pred = broadcast_in_dim(pred, on_true.shape, [0])
|
||||
return select(pred, on_true, on_false), 0
|
||||
|
||||
def _select_papply_rule(name, vals, dims):
|
||||
dimset = set([d for d in dims if d is not None])
|
||||
if len(dimset) != 1:
|
||||
raise NotImplementedError(
|
||||
'papply of select with operands split along different dimensions')
|
||||
like_val, like_dim = [(v, d) for v, d in zip(vals, dims) if d is not None][0]
|
||||
|
||||
def normalize_split(val, dim):
|
||||
return psplit_like(val, like_val, name) if dim is None else val
|
||||
|
||||
vals = [normalize_split(v, d) for v, d in zip(vals, dims)]
|
||||
return select_p.bind(*vals), like_dim
|
||||
|
||||
select_p = standard_primitive(_select_shape_rule, _select_dtype_rule, 'select')
|
||||
ad.defjvp(select_p,
|
||||
None,
|
||||
@ -2644,6 +2657,8 @@ ad.defjvp(select_p,
|
||||
lambda g, b, x, y: select(b, _zeros(g), g))
|
||||
ad.primitive_transposes[select_p] = _select_transpose_rule
|
||||
batching.primitive_batchers[select_p] = _select_batch_rule
|
||||
parallel.papply_primitive_rules[select_p] = _select_papply_rule
|
||||
|
||||
|
||||
def _slice_shape_rule(operand, start_indices, limit_indices, strides,
|
||||
operand_shape):
|
||||
|
@ -123,6 +123,28 @@ class PapplyTest(jtu.JaxTestCase):
|
||||
expected = onp.max(arg, axis=0)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testSelect(self):
|
||||
pfun, axis_name = papply(lax.select, 5,
|
||||
in_axes=(None, 0, None))
|
||||
|
||||
p = onp.arange(15).reshape((5, 3)) % 4 == 1
|
||||
t = onp.ones((5, 3))
|
||||
f = onp.zeros((5, 3))
|
||||
jaxpr = make_jaxpr(pfun)(p, t[0], f)
|
||||
|
||||
def expected_spmd(p, t, f):
|
||||
return lax.select(
|
||||
lax.psplit_like(p, t, axis_name),
|
||||
t,
|
||||
lax.psplit_like(f, t, axis_name))
|
||||
|
||||
expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
|
||||
assert repr(jaxpr) == repr(expected_jaxpr)
|
||||
|
||||
ans = serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
|
||||
expected = lax.select(p, t, f)
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
@skip
|
||||
def DISABLED_testLogSoftmax(self):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user