Don't psum over auto mesh dims in _unmentioned2.

PiperOrigin-RevId: 698440525
Author: Parker Schuh (committed by jax authors)
Date:   2024-11-20 10:35:16 -08:00
Parent: eab9026c14
Commit: 2c9b917b9d

2 changed files with 28 additions and 4 deletions

jax/experimental/shard_map.py

```diff
@@ -1547,10 +1547,11 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
   return jaxpr
 
-def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
+def _unmentioned2(mesh: Mesh, names: AxisNames,
+                  auto: frozenset[AxisName]) -> list[AxisName]:
   # We use a filtered-down version of unmentioned to avoid defensive-psum over
   # more chips than required in the transpose-no-check-rep case.
-  name_set = {n for ns in names.values() for n in ns}
+  name_set = {n for ns in names.values() for n in ns} | auto
   return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]
```
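For intuition, here is a minimal standalone sketch of what the updated helper computes. It is an approximation: `unmentioned2_sketch` is a hypothetical name, and a plain tuple of axis names stands in for the real `_all_mesh_names_except_spmd(mesh)`, which additionally filters out SPMD-inserted axis names.

```python
# Hypothetical, simplified model of _unmentioned2 after this change.
def unmentioned2_sketch(mesh_axis_names, names, auto):
  # names maps each array dimension to the mesh axes it is sharded over.
  name_set = {n for ns in names.values() for n in ns} | auto
  return [n for n in mesh_axis_names if n not in name_set]

# On an ('i', 'j') mesh, an array sharded only over 'i':
print(unmentioned2_sketch(('i', 'j'), {0: ('i',)}, frozenset()))       # ['j']
# With 'j' marked auto, it no longer counts as unmentioned:
print(unmentioned2_sketch(('i', 'j'), {0: ('i',)}, frozenset({'j'})))  # []
```

The effect is that the transpose below never tries to psum over an auto axis, which the partitioner handles on its own.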
```diff
@@ -1559,7 +1560,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
   mb_div = lambda x, y: x / y if y != 1 else x
   out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
              else x if rewrite or dtypes.dtype(x) == dtypes.float0
-             else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
+             else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
              for ns, x in zip(out_names, out_cts)]
   args = [x if type(x) is not ad.UndefinedPrimal else
           ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
@@ -1577,7 +1578,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
       )
   out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
          else x if rewrite
-         else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns)))
+         else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
          for ns, x in zip(in_names, out)]
   return out
```
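The two call sites pair up: incoming output cotangents are divided by the product of the unmentioned axis sizes (`mb_div`), and the resulting input cotangents are psummed over those same axes. An auto axis must not appear in either list, both because the partitioner already accounts for it and because collectives inside the body can only name manual (non-auto) mesh axes. A scalar sketch of the pairing, with illustrative axis sizes, modeling a psum of a replicated value over an axis of size k as multiplication by k:

```python
from math import prod

mesh_shape = {'i': 2, 'j': 2}  # illustrative mesh axis sizes
unmentioned = ['j']            # what _unmentioned2 returns without the auto filter

ct = 8.0                                        # incoming output cotangent
ct /= prod(mesh_shape[n] for n in unmentioned)  # mb_div step: 8.0 -> 4.0
# ... the transposed body runs on the divided cotangent ...
ct *= prod(mesh_shape[n] for n in unmentioned)  # psum model: 4.0 -> 8.0

# With 'j' in auto, unmentioned == [] and both steps are no-ops, so no
# psum is attempted over an axis the body never binds.
```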

tests/shard_map_test.py

```diff
@@ -2046,6 +2046,29 @@ class ShardMapTest(jtu.JaxTestCase):
     v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
     self.assertAllClose(v*v, f(v), check_dtypes=False)
 
+  def test_grad_nested_partial_auto(self):
+    mesh = jtu.create_mesh((2, 2), ('i', 'j'))
+
+    def g(x):
+      return x * x
+
+    def h(x):
+      return shard_map(g, mesh,
+                       in_specs=P(None, 'j'),
+                       out_specs=P(None, 'j'))(x)
+
+    @jax.jit
+    def f(x):
+      return shard_map(h, mesh,
+                       in_specs=P('i', None),
+                       out_specs=P('i', None),
+                       check_rep=False,
+                       auto=frozenset({'j'}))(x).sum()
+
+    v = jnp.arange(32.).reshape(4, 8)
+    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+    self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)
+
   def test_axis_size_1_partial_auto(self):
     mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))
```
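For reference, the new test restated as a standalone script: a sketch that assumes at least four devices, swaps `jtu.create_mesh` for a plain `Mesh`, and mirrors `assertAllClose` with `np.testing.assert_allclose`:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Needs >= 4 devices; on CPU, try XLA_FLAGS=--xla_force_host_platform_device_count=4
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))

def g(x):
  return x * x

def h(x):
  # Inner shard_map over the auto axis 'j'.
  return shard_map(g, mesh, in_specs=P(None, 'j'), out_specs=P(None, 'j'))(x)

@jax.jit
def f(x):
  # Outer shard_map is manual over 'i' and leaves 'j' to the partitioner.
  return shard_map(h, mesh, in_specs=P('i', None), out_specs=P('i', None),
                   check_rep=False, auto=frozenset({'j'}))(x).sum()

v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, NamedSharding(mesh, P('i', 'j')))
np.testing.assert_allclose(jax.grad(f)(v), v * 2)  # d/dv sum(v*v) == 2v
```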