mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Order axis names for shard_map residuals
This commit is contained in:
parent
5e2710c2c2
commit
a4f090819f
@ -1370,7 +1370,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
in_fwd, out_fwd, out_knowns, _, jaxpr, _ = aux()
|
||||
_, out_known_names = pe.partition_list(out_knowns, out_names_thunk())
|
||||
num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
return (*out_known_names, *({0: (*all_names,)},) * num_res)
|
||||
return (*out_known_names, *({0: all_names},) * num_res)
|
||||
|
||||
known_params = dict(mesh=mesh, in_names=(*known_in_names,),
|
||||
out_names_thunk=known_out_names, check_rep=check_rep,
|
||||
@ -1385,7 +1385,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
|
||||
res_names = [known_in_names[f1] if f1 is not None else
|
||||
known_out_names_[f2] if f2 is not None else
|
||||
{0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
unk_in_names = (*res_names,) + ({},) * len(env) + (*unk_in_names,)
|
||||
const_tracers = map(trace.new_instantiated_const, res)
|
||||
env_tracers = map(trace.full_raise, env)
|
||||
@ -1428,7 +1428,7 @@ def _shard_map_partial_eval_post_process(
|
||||
const_tracers = map(trace.new_instantiated_const, res_)
|
||||
env_tracers = map(trace.full_raise, env)
|
||||
|
||||
staged_in_names = ({0: (*all_names,)},) * len(res_) + ({},) * len(env)
|
||||
staged_in_names = ({0: all_names},) * len(res_) + ({},) * len(env)
|
||||
staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
|
||||
out_names=(*out_names_unknown,), check_rep=False,
|
||||
rewrite=rewrite, auto=auto)
|
||||
@ -1447,7 +1447,7 @@ def _shard_map_partial_eval_post_process(
|
||||
def out_names_transform(out_names):
|
||||
nonlocal out_names_unknown
|
||||
out_names_unknown, out_names_known = partition_list(out_knowns, out_names)
|
||||
return (*out_names_known,) + ({0: (*all_names,)},) * len(res)
|
||||
return (*out_names_known,) + ({0: all_names},) * len(res)
|
||||
out_names_unknown: list | None = None
|
||||
|
||||
return out, (todo, out_names_transform)
|
||||
@ -1560,7 +1560,7 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
params_known, params_staged, all_names = _pe_custom_params(
|
||||
unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, which,
|
||||
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
|
||||
residuals = [newvar(_unshard_aval(mesh, {0: (*all_names,)}, var.aval))
|
||||
residuals = [newvar(_unshard_aval(mesh, {0: all_names}, var.aval))
|
||||
for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]
|
||||
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
|
||||
eqn.primitive, params_known, jaxpr_known.effects,
|
||||
@ -1612,7 +1612,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
all_names = _all_mesh_names(mesh)
|
||||
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
|
||||
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
|
||||
out_names_known = out_names_known + [{0: (*all_names,)}] * sum(which)
|
||||
out_names_known = out_names_known + [{0: all_names}] * sum(which)
|
||||
new_params_known = dict(params_known, in_names=tuple(in_names_known),
|
||||
out_names=tuple(out_names_known))
|
||||
|
||||
@ -1620,7 +1620,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
_, in_names_staged = partition_list(inst_in, params_staged['in_names'])
|
||||
res_names = [in_names_known[f1] if f1 is not None else
|
||||
out_names_known[f2] if f2 is not None else
|
||||
{0: (*all_names,)} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
{0: all_names} for f1, f2 in zip(in_fwd, out_fwd)]
|
||||
in_names_staged = res_names + in_names_staged
|
||||
_, out_names_staged = partition_list(kept_outs_staged, params_staged['out_names'])
|
||||
new_params_staged = dict(params_staged, in_names=tuple(in_names_staged),
|
||||
@ -1629,12 +1629,12 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
|
||||
|
||||
# TODO(mattjj): remove this mechanism when we revise mesh scopes
|
||||
def _all_mesh_names(mesh: Mesh) -> set[AxisName]:
|
||||
def _all_mesh_names(mesh: Mesh) -> tuple[AxisName, ...]:
|
||||
stack = core.thread_local_state.trace_state.trace_stack.stack
|
||||
names = {n for frame in stack
|
||||
if (ns := frame.payload.get('spmd_axis_name', ())) is not None
|
||||
for n in ns}
|
||||
return set(mesh.axis_names) - names
|
||||
return tuple(name for name in mesh.axis_names if name not in names)
|
||||
|
||||
|
||||
# DCE
|
||||
|
@ -1777,6 +1777,26 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash
|
||||
|
||||
def test_grad_shmap_residuals_axis_names_in_mesh_order(self):
|
||||
# https://github.com/google/jax/issues/21236
|
||||
mesh = jtu.create_global_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a'))
|
||||
|
||||
@partial(
|
||||
shard_map,
|
||||
mesh=mesh,
|
||||
in_specs=P('j'),
|
||||
out_specs=P('j'),
|
||||
)
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
xs = jnp.arange(16.)
|
||||
|
||||
ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs)
|
||||
self.assertIn(
|
||||
'{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}',
|
||||
ir.as_text()
|
||||
)
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user