shard_map: build residual axis names in mesh axis order.

_all_mesh_names now returns a tuple ordered like mesh.axis_names (minus any
names claimed by an enclosing spmd_axis_name) instead of a set, so every
`{0: all_names}` residual spec lists the axes in a stable order. Call sites
drop the `(*all_names,)` re-tupling, and a regression test checks the lowered
IR text. See https://github.com/google/jax/issues/21236.

NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch. Hunk line counts were checked against
the @@ headers, but indent depth and blank-line placement in context lines
should be confirmed against the upstream files before applying.

diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index cda162f29..f63c9b841 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -1370,7 +1370,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
     in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux()
     _, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
     num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
-    return (*out_known_names, *({0: (*all_names,)},) * num_res)
+    return (*out_known_names, *({0: all_names},) * num_res)
 
   known_params = dict(mesh=mesh, in_names=(*known_in_names,),
                       out_names_thunk=known_out_names, check_rep=check_rep,
@@ -1385,7 +1385,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
   res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
   res_names = [known_in_names[f1] if f1 is not None else
                known_out_names_[f2] if f2 is not None else
-               {0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+               {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
   unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
   const_tracers = map(trace.new_instantiated_const, res)
   env_tracers = map(trace.full_raise, env)
@@ -1428,7 +1428,7 @@ def _shard_map_partial_eval_post_process(
     const_tracers = map(trace.new_instantiated_const, res_)
     env_tracers = map(trace.full_raise, env)
 
-    staged_in_names = ({0: (*all_names,)},) * len(res_) + ({},) * len(env)
+    staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env)
     staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
                          out_names=(*out_names_unknown,), check_rep=False,
                          rewrite=rewrite, auto=auto)
@@ -1447,7 +1447,7 @@ def _shard_map_partial_eval_post_process(
   def out_names_transform(out_names):
     nonlocal out_names_unknown
     out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
-    return (*out_names_known,) + ({0: (*all_names,)},) * len(res)
+    return (*out_names_known,) + ({0: all_names},) * len(res)
   out_names_unknown: list | None = None
 
   return out, (todo, out_names_transform)
@@ -1560,7 +1560,7 @@ def _partial_eval_jaxpr_custom_rule(
   params_known, params_staged, all_names = _pe_custom_params(
       unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
       dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
-  residuals = [newvar(_unshard_aval(mesh, {0: (*all_names,)}, var.aval))
+  residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
                for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
   eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                                eqn.primitive, params_known, jaxpr_known.effects,
@@ -1612,7 +1612,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
   all_names = _all_mesh_names(mesh)
   in_names_known, _ = partition_list(unks_in, params_known['in_names'])
   _, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
-  out_names_known = out_names_known + [{0: (*all_names,)}] * sum(which)
+  out_names_known = out_names_known + [{0: all_names}] * sum(which)
   new_params_known = dict(params_known, in_names=tuple(in_names_known),
                           out_names=tuple(out_names_known))
 
@@ -1620,7 +1620,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
   _, in_names_staged = partition_list(inst_in, params_staged['in_names'])
   res_names = [in_names_known[f1] if f1 is not None else
                out_names_known[f2] if f2 is not None else
-               {0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
+               {0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
   in_names_staged = res_names + in_names_staged
   _, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
   new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
@@ -1629,12 +1629,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
 
 
 # TODO(mattjj): remove this mechanism when we revise mesh scopes
-def _all_mesh_names(mesh: Mesh) -> set[AxisName]:
+def _all_mesh_names(mesh: Mesh) -> tuple[AxisName, ...]:
   stack = core.thread_local_state.trace_state.trace_stack.stack
   names = {n for frame in stack
            if (ns := frame.payload.get('spmd_axis_name', ())) is not None
            for n in ns}
-  return set(mesh.axis_names) - names
+  return tuple(name for name in mesh.axis_names if name not in names)
 
 
 # DCE
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 5de80d7c8..b02cd17e7 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1777,6 +1777,26 @@ class ShardMapTest(jtu.JaxTestCase):
 
     jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs)  # don't crash
 
+  def test_grad_shmap_residuals_axis_names_in_mesh_order(self):
+    # https://github.com/google/jax/issues/21236
+    mesh = jtu.create_global_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))
+
+    @partial(
+        shard_map,
+        mesh=mesh,
+        in_specs=P('j'),
+        out_specs=P('j'),
+    )
+    def f(x):
+      return jnp.sin(x)
+
+    xs = jnp.arange(16.)
+
+    ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs)
+    self.assertIn(
+        '{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}',
+        ir.as_text()
+    )
 
 class FunSpec(NamedTuple):
   name: str