Order axis names for shard_map residuals

This commit is contained in:
Jaroslav Sevcik 2024-05-17 06:21:59 -07:00
parent 5e2710c2c2
commit a4f090819f
2 changed files with 29 additions and 9 deletions

View File

@ -1370,7 +1370,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux()
_, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
return (*out_known_names, *({0: (*all_names,)},) * num_res)
return (*out_known_names, *({0: all_names},) * num_res)
known_params = dict(mesh=mesh, in_names=(*known_in_names,),
out_names_thunk=known_out_names, check_rep=check_rep,
@ -1385,7 +1385,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
res_names = [known_in_names[f1] if f1 is not None else
known_out_names_[f2] if f2 is not None else
{0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.full_raise, env)
@ -1428,7 +1428,7 @@ def _shard_map_partial_eval_post_process(
const_tracers = map(trace.new_instantiated_const, res_)
env_tracers = map(trace.full_raise, env)
staged_in_names = ({0: (*all_names,)},) * len(res_) + ({},) * len(env)
staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env)
staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
out_names=(*out_names_unknown,), check_rep=False,
rewrite=rewrite, auto=auto)
@ -1447,7 +1447,7 @@ def _shard_map_partial_eval_post_process(
def out_names_transform(out_names):
nonlocal out_names_unknown
out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
return (*out_names_known,) + ({0: (*all_names,)},) * len(res)
return (*out_names_known,) + ({0: all_names},) * len(res)
out_names_unknown: list | None = None
return out, (todo, out_names_transform)
@ -1560,7 +1560,7 @@ def _partial_eval_jaxpr_custom_rule(
params_known, params_staged, all_names = _pe_custom_params(
unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
residuals = [newvar(_unshard_aval(mesh, {0: (*all_names,)}, var.aval))
residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects,
@ -1612,7 +1612,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
all_names = _all_mesh_names(mesh)
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
out_names_known = out_names_known + [{0: (*all_names,)}] * sum(which)
out_names_known = out_names_known + [{0: all_names}] * sum(which)
new_params_known = dict(params_known, in_names=tuple(in_names_known),
out_names=tuple(out_names_known))
@ -1620,7 +1620,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
_, in_names_staged = partition_list(inst_in, params_staged['in_names'])
res_names = [in_names_known[f1] if f1 is not None else
out_names_known[f2] if f2 is not None else
{0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
in_names_staged = res_names + in_names_staged
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
@ -1629,12 +1629,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names(mesh: Mesh) -> set[AxisName]:
def _all_mesh_names(mesh: Mesh) -> tuple[AxisName, ...]:
stack = core.thread_local_state.trace_state.trace_stack.stack
names = {n for frame in stack
if (ns := frame.payload.get('spmd_axis_name', ())) is not None
for n in ns}
return set(mesh.axis_names) - names
return tuple(name for name in mesh.axis_names if name not in names)
# DCE

View File

@ -1777,6 +1777,26 @@ class ShardMapTest(jtu.JaxTestCase):
jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash
def test_grad_shmap_residuals_axis_names_in_mesh_order(self):
# https://github.com/google/jax/issues/21236
mesh = jtu.create_global_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))
@partial(
shard_map,
mesh=mesh,
in_specs=P('j'),
out_specs=P('j'),
)
def f(x):
return jnp.sin(x)
xs = jnp.arange(16.)
ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs)
self.assertIn(
'{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}',
ir.as_text()
)
class FunSpec(NamedTuple):
name: str