Only add new manual axes to residuals when adding axes with partial_auto.

PiperOrigin-RevId: 728839349
This commit is contained in:
Parker Schuh 2025-02-19 15:27:00 -08:00 committed by jax authors
parent dbb46e9214
commit b7c66bd22e
2 changed files with 37 additions and 6 deletions

View File

@ -1525,7 +1525,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
in_pvals = [t.pval for t in tracers]
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
all_names = _all_mesh_names_except_spmd(mesh, auto, trace)
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
f = _promote_scalar_residuals(f)
@ -1542,8 +1542,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
known_params = dict(mesh=mesh, in_names=(*known_in_names,),
out_names_thunk=known_out_names, check_rep=check_rep,
rewrite=rewrite, auto=auto)
with _extend_axis_env(mesh, auto):
out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params)
out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params)
in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux()
num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
@ -1581,7 +1580,7 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
f.debug_info)
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
all_names = _all_mesh_names_except_spmd(mesh, auto, trace)
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
@as_hashable_function(closure=(linearize_outs_thunk))
def primal_out_names_thunk():
@ -1814,7 +1813,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
# prune inputs to jaxpr_known according to unks_in
mesh = params_known['mesh']
auto = params_known['auto']
all_names = _all_mesh_names_except_spmd(mesh, auto)
all_names = _all_newly_manual_mesh_names(mesh, auto)
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
out_names_known = out_names_known + [{0: all_names}] * sum(which)
@ -1838,10 +1837,19 @@ def _all_mesh_names_except_spmd(
) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
spmd_names = axis_env.spmd_axis_names
axis_sizes = axis_env.axis_sizes
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto)
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_newly_manual_mesh_names(
mesh: Mesh, auto: frozenset[AxisName], trace=None
) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
spmd_names = axis_env.spmd_axis_names
axis_sizes = axis_env.axis_sizes
return tuple(name for name in mesh.axis_names if name not in spmd_names and
name not in auto and name not in axis_sizes)
# DCE
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?

View File

@ -2086,6 +2086,29 @@ class ShardMapTest(jtu.JaxTestCase):
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)
def test_grad_nested_partial_auto_with_residuals(self):
# Checks jax.grad through two nested shard_maps over the same mesh, where the
# outer one leaves axis 'j' in `auto` and the inner one maps over 'j'.
# NOTE(review): going by the test name, the cube is presumably there so that
# differentiation has to save residuals across both shard_maps -- confirm.
# 2x2 device mesh with axes named 'i' and 'j'.
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
# Innermost computation: x cubed, written as two multiplications.
def g(x):
return x * x * x
# Inner shard_map: applies g with the second array dimension split over 'j'.
def h(x):
return shard_map(g, mesh,
in_specs=P(None, 'j'),
out_specs=P(None, 'j'))(x)
# Outer shard_map (under jit): first dimension split over 'i', 'j' passed as
# an auto axis, replication checking disabled; the result is summed so that
# jax.grad receives a scalar-valued function.
@jax.jit
def f(x):
return shard_map(h, mesh,
in_specs=P('i', None),
out_specs=P('i', None),
check_rep=False,
auto=frozenset({'j'}))(x).sum()
# 4x8 input holding 0..31, placed on the mesh sharded over both axes.
v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
# d/dx sum(x**3) == 3 * x**2, compared elementwise against the computed grad.
self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False)
def test_axis_size_1_partial_auto(self):
mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))