mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Only add new manual axes to residuals when adding axes with partial_auto.
PiperOrigin-RevId: 728839349
This commit is contained in:
parent
dbb46e9214
commit
b7c66bd22e
@ -1525,7 +1525,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
|
||||
in_pvals = [t.pval for t in tracers]
|
||||
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
|
||||
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
|
||||
all_names = _all_mesh_names_except_spmd(mesh, auto, trace)
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
|
||||
in_avals_sharded = map(partial(_shard_aval, mesh), unk_in_names, in_avals)
|
||||
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
|
||||
f = _promote_scalar_residuals(f)
|
||||
@ -1542,8 +1542,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
|
||||
known_params = dict(mesh=mesh, in_names=(*known_in_names,),
|
||||
out_names_thunk=known_out_names, check_rep=check_rep,
|
||||
rewrite=rewrite, auto=auto)
|
||||
with _extend_axis_env(mesh, auto):
|
||||
out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params)
|
||||
out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts), known_params)
|
||||
in_fwd, out_fwd, out_knowns, out_avals_sharded, jaxpr, env = aux()
|
||||
num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
|
||||
out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
|
||||
@ -1581,7 +1580,7 @@ def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
|
||||
f.debug_info)
|
||||
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
|
||||
tangent_in_names = [ax for ax, nz in zip(in_names, nzs_in) if nz]
|
||||
all_names = _all_mesh_names_except_spmd(mesh, auto, trace)
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto, trace)
|
||||
|
||||
@as_hashable_function(closure=(linearize_outs_thunk))
|
||||
def primal_out_names_thunk():
|
||||
@ -1814,7 +1813,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
# prune inputs to jaxpr_known according to unks_in
|
||||
mesh = params_known['mesh']
|
||||
auto = params_known['auto']
|
||||
all_names = _all_mesh_names_except_spmd(mesh, auto)
|
||||
all_names = _all_newly_manual_mesh_names(mesh, auto)
|
||||
in_names_known, _ = partition_list(unks_in, params_known['in_names'])
|
||||
_, out_names_known = partition_list(kept_outs_known, params_known['out_names'])
|
||||
out_names_known = out_names_known + [{0: all_names}] * sum(which)
|
||||
@ -1838,10 +1837,19 @@ def _all_mesh_names_except_spmd(
|
||||
) -> tuple[AxisName, ...]:
|
||||
axis_env = core.get_axis_env()
|
||||
spmd_names = axis_env.spmd_axis_names
|
||||
axis_sizes = axis_env.axis_sizes
|
||||
return tuple(name for name in mesh.axis_names if name not in spmd_names and
|
||||
name not in auto)
|
||||
|
||||
# TODO(mattjj): remove this mechanism when we revise mesh scopes
|
||||
def _all_newly_manual_mesh_names(
|
||||
mesh: Mesh, auto: frozenset[AxisName], trace=None
|
||||
) -> tuple[AxisName, ...]:
|
||||
axis_env = core.get_axis_env()
|
||||
spmd_names = axis_env.spmd_axis_names
|
||||
axis_sizes = axis_env.axis_sizes
|
||||
return tuple(name for name in mesh.axis_names if name not in spmd_names and
|
||||
name not in auto and name not in axis_sizes)
|
||||
|
||||
# DCE
|
||||
|
||||
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
|
||||
|
@ -2086,6 +2086,29 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
|
||||
self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)
|
||||
|
||||
def test_grad_nested_partial_auto_with_residuals(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
def g(x):
|
||||
return x * x * x
|
||||
|
||||
def h(x):
|
||||
return shard_map(g, mesh,
|
||||
in_specs=P(None, 'j'),
|
||||
out_specs=P(None, 'j'))(x)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return shard_map(h, mesh,
|
||||
in_specs=P('i', None),
|
||||
out_specs=P('i', None),
|
||||
check_rep=False,
|
||||
auto=frozenset({'j'}))(x).sum()
|
||||
|
||||
v = jnp.arange(32.).reshape(4, 8)
|
||||
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
|
||||
self.assertAllClose(v*v*3, jax.grad(f)(v), check_dtypes=False)
|
||||
|
||||
def test_axis_size_1_partial_auto(self):
|
||||
mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user