Merge pull request #15481 from mattjj:shmap-scan-rep-rule

PiperOrigin-RevId: 522885050
This commit is contained in:
jax authors 2023-04-08 20:58:28 -07:00
commit 727f68b952
2 changed files with 62 additions and 4 deletions

View File

@ -39,10 +39,10 @@ from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, fft, linalg)
windowed_reductions, fft, linalg, control_flow)
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
as_hashable_function, memoize, partition_list,
merge_lists)
merge_lists, split_list)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -712,6 +712,20 @@ def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
def _debug_callback_rule(mesh, *in_rep, **_):
return []
@register_rule(control_flow.loops.scan_p)
def _scan_rule(mesh, *in_rep, jaxpr, num_consts, num_carry, linear, length,
reverse, unroll):
const_rep, carry_rep, xs_rep = split_list(in_rep, [num_consts, num_carry])
for _ in range(1 + num_carry):
out_rep = _output_rep(mesh, jaxpr.jaxpr, [*const_rep, *carry_rep, *xs_rep])
if carry_rep == out_rep[:num_carry]:
break
else:
carry_rep = map(op.and_, carry_rep, out_rep[:num_carry])
else:
assert False, 'Fixpoint not reached'
return out_rep
# Batching
def _shard_map_batch(
@ -831,7 +845,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk=known_out_names, check_rep=check_rep)
out = shard_map_p.bind(f_known, *in_consts, **known_params)
out_knowns, out_avals_sharded, jaxpr, env = aux()
out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
@ -865,7 +879,7 @@ def _shard_map_partial_eval_post_process(
def todo(out):
trace = main.with_cur_sublevel()
out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env)

View File

@ -595,6 +595,50 @@ class ShardMapTest(jtu.JaxTestCase):
shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash
def test_scan_rep_rule(self):
mesh = jtu.create_global_mesh((2, 2,), ('x', 'y'))
def f(x, y, z):
x, y, z = x.sum(), y.sum(), z.sum()
def body(c, _):
c, *cs = c
return (*cs, c), None
out, _ = jax.lax.scan(body, (x, y, z), None, length=3)
return [jnp.expand_dims(a, 0) for a in out]
x = jnp.arange(4)
# doesn't crash, because out_spec assumes no replication (and there is none)
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=P(('x', 'y')))(x, x, x)
# does crash, because output incorrectly promises replication
with self.assertRaisesRegex(ValueError, "require replication"):
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=P('x'))(x, x, x)
with self.assertRaisesRegex(ValueError, "require replication"):
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=P('y'))(x, x, x)
with self.assertRaisesRegex(ValueError, "require replication"):
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=P(None))(x, x, x)
def g(x, y, z):
x, y, z = x.sum(), y.sum(), z.sum()
def body(c, _):
return c, None
out, _ = jax.lax.scan(body, (x, y, z), None, length=1)
return [jnp.expand_dims(a, 0) for a in out]
# doesn't crash, because everything matches
shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x)
# does crash, because the second guy is wrong
with self.assertRaisesRegex(ValueError, "require replication"):
shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)
class FunSpec(NamedTuple):
name: str