1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Minimizes defensive psum in shard_map transpose with check_rep=False.

By summing over fewer devices than before, this version should be more numerically stable.

PiperOrigin-RevId: 645499243
This commit is contained in:
Keith Rush 2024-06-21 14:17:37 -07:00 committed by jax authors
parent 9e30079dba
commit 694cafb72b

@ -404,6 +404,7 @@ def _unmentioned(mesh: Mesh, names: AxisNames) -> list[AxisName]:
name_set = {n for ns in names.values() for n in ns}
return [n for n in mesh.axis_names if n not in name_set]
def _try_infer_args(f, tree):
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
try:
@ -1481,13 +1482,21 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr
def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
  """Return the axis names of `mesh` that `names` never refers to.

  The candidates come from `_all_mesh_names(mesh)` and keep that order; a
  candidate is dropped if it appears in any of the values of `names`.

  This is a filtered-down variant of `_unmentioned`, used in the transpose
  path when replication checking is off so that the defensive psum does not
  run over more chips than required.
  """
  # Gather every axis name that appears in any entry of `names`.
  mentioned = set()
  for axis_group in names.values():
    mentioned.update(axis_group)

  # Keep only the candidates that were never mentioned, preserving order.
  return [axis for axis in _all_mesh_names(mesh) if axis not in mentioned]
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep, rewrite, auto):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite
else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
else x if rewrite
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
for ns, x in zip(in_names, args)]
@ -1503,8 +1512,9 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts
)
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
for ns, x in zip(in_names, out)]
else x if rewrite
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns)))
for ns, x in zip(in_names, out)]
return out
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)