mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Minimizes defensive psum in shard_map transpose with check_rep=False.
By summing over fewer things, this version should be more numerically stable. PiperOrigin-RevId: 645499243
This commit is contained in:
parent
9e30079dba
commit
694cafb72b
@ -404,6 +404,7 @@ def _unmentioned(mesh: Mesh, names: AxisNames) -> list[AxisName]:
|
||||
name_set = {n for ns in names.values() for n in ns}
|
||||
return [n for n in mesh.axis_names if n not in name_set]
|
||||
|
||||
|
||||
def _try_infer_args(f, tree):
|
||||
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
|
||||
try:
|
||||
@ -1481,13 +1482,21 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
|
||||
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
|
||||
return jaxpr
|
||||
|
||||
|
||||
def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
  """Return the names from ``_all_mesh_names(mesh)`` that ``names`` never uses.

  This is a filtered-down version of ``_unmentioned``, used to avoid a
  defensive psum over more chips than required in the
  transpose-no-check-rep case.
  """
  # Gather every axis name referenced by any entry of `names`.
  mentioned = set()
  for axis_group in names.values():
    mentioned.update(axis_group)
  # Keep the order produced by `_all_mesh_names`, dropping mentioned axes.
  return [axis for axis in _all_mesh_names(mesh) if axis not in mentioned]
|
||||
|
||||
|
||||
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
|
||||
check_rep, rewrite, auto):
|
||||
mb_div = lambda x, y: x / y if y != 1 else x
|
||||
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
|
||||
else x if rewrite
|
||||
else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
|
||||
for ns, x in zip(out_names, out_cts)]
|
||||
else x if rewrite
|
||||
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
|
||||
for ns, x in zip(out_names, out_cts)]
|
||||
args = [x if type(x) is not ad.UndefinedPrimal else
|
||||
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
|
||||
for ns, x in zip(in_names, args)]
|
||||
@ -1503,8 +1512,9 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
|
||||
jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts
|
||||
)
|
||||
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
|
||||
else x if rewrite else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
|
||||
for ns, x in zip(in_names, out)]
|
||||
else x if rewrite
|
||||
else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns)))
|
||||
for ns, x in zip(in_names, out)]
|
||||
return out
|
||||
|
||||
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
|
||||
|
Loading…
x
Reference in New Issue
Block a user