mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15481 from mattjj:shmap-scan-rep-rule
PiperOrigin-RevId: 522885050
This commit is contained in:
commit
727f68b952
@ -39,10 +39,10 @@ from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.core import Tracer
|
||||
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
|
||||
windowed_reductions, fft, linalg)
|
||||
windowed_reductions, fft, linalg, control_flow)
|
||||
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
|
||||
as_hashable_function, memoize, partition_list,
|
||||
merge_lists)
|
||||
merge_lists, split_list)
|
||||
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -712,6 +712,20 @@ def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
|
||||
def _debug_callback_rule(mesh, *in_rep, **_):
|
||||
return []
|
||||
|
||||
@register_rule(control_flow.loops.scan_p)
|
||||
def _scan_rule(mesh, *in_rep, jaxpr, num_consts, num_carry, linear, length,
|
||||
reverse, unroll):
|
||||
const_rep, carry_rep, xs_rep = split_list(in_rep, [num_consts, num_carry])
|
||||
for _ in range(1 + num_carry):
|
||||
out_rep = _output_rep(mesh, jaxpr.jaxpr, [*const_rep, *carry_rep, *xs_rep])
|
||||
if carry_rep == out_rep[:num_carry]:
|
||||
break
|
||||
else:
|
||||
carry_rep = map(op.and_, carry_rep, out_rep[:num_carry])
|
||||
else:
|
||||
assert False, 'Fixpoint not reached'
|
||||
return out_rep
|
||||
|
||||
# Batching
|
||||
|
||||
def _shard_map_batch(
|
||||
@ -831,7 +845,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
out_names_thunk=known_out_names, check_rep=check_rep)
|
||||
out = shard_map_p.bind(f_known, *in_consts, **known_params)
|
||||
out_knowns, out_avals_sharded, jaxpr, env = aux()
|
||||
out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
|
||||
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
|
||||
unk_out_names, _ = pe.partition_list(out_knowns, out_names_thunk())
|
||||
@ -865,7 +879,7 @@ def _shard_map_partial_eval_post_process(
|
||||
|
||||
def todo(out):
|
||||
trace = main.with_cur_sublevel()
|
||||
out_consts, res = pe.split_list(out, [len(out) - len(jaxpr.constvars)])
|
||||
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
|
||||
const_tracers = map(trace.new_instantiated_const, res)
|
||||
env_tracers = map(trace.full_raise, env)
|
||||
|
||||
|
@ -595,6 +595,50 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash
|
||||
|
||||
def test_scan_rep_rule(self):
|
||||
mesh = jtu.create_global_mesh((2, 2,), ('x', 'y'))
|
||||
|
||||
def f(x, y, z):
|
||||
x, y, z = x.sum(), y.sum(), z.sum()
|
||||
def body(c, _):
|
||||
c, *cs = c
|
||||
return (*cs, c), None
|
||||
out, _ = jax.lax.scan(body, (x, y, z), None, length=3)
|
||||
return [jnp.expand_dims(a, 0) for a in out]
|
||||
|
||||
x = jnp.arange(4)
|
||||
|
||||
# doesn't crash, because out_spec assumes no replication (and there is none)
|
||||
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
|
||||
out_specs=P(('x', 'y')))(x, x, x)
|
||||
|
||||
# does crash, because output incorrectly promises replication
|
||||
with self.assertRaisesRegex(ValueError, "require replication"):
|
||||
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
|
||||
out_specs=P('x'))(x, x, x)
|
||||
with self.assertRaisesRegex(ValueError, "require replication"):
|
||||
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
|
||||
out_specs=P('y'))(x, x, x)
|
||||
with self.assertRaisesRegex(ValueError, "require replication"):
|
||||
shard_map(f, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
|
||||
out_specs=P(None))(x, x, x)
|
||||
|
||||
def g(x, y, z):
|
||||
x, y, z = x.sum(), y.sum(), z.sum()
|
||||
def body(c, _):
|
||||
return c, None
|
||||
out, _ = jax.lax.scan(body, (x, y, z), None, length=1)
|
||||
return [jnp.expand_dims(a, 0) for a in out]
|
||||
|
||||
# doesn't crash, because everything matches
|
||||
shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
|
||||
out_specs=[P(None), P('x'), P(('x', 'y'))])(x, x, x)
|
||||
|
||||
# does crash, because the second guy is wrong
|
||||
with self.assertRaisesRegex(ValueError, "require replication"):
|
||||
shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
|
||||
out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user