Don't psum over auto mesh dims in _unmentioned2.

PiperOrigin-RevId: 698440525

parent: eab9026c14
commit: 2c9b917b9d
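
In short: axis names delegated to `auto` are now folded into the "mentioned" set, so the transpose no longer emits a defensive psum (or the matching division) over those mesh dimensions. A minimal standalone sketch of that filtering, using a hypothetical helper and mesh rather than the library internals:

    # Standalone sketch of the new filtering, not the JAX implementation.
    def unmentioned_axes(mesh_axes, mentioned, auto=frozenset()):
        # After this commit, auto axes count as mentioned, so they are
        # excluded from the list of axes to psum over.
        name_set = set(mentioned) | auto
        return [a for a in mesh_axes if a not in name_set]

    mesh_axes = ('i', 'j')
    print(unmentioned_axes(mesh_axes, {'i'}))                         # ['j']  old: psum over 'j'
    print(unmentioned_axes(mesh_axes, {'i'}, auto=frozenset({'j'})))  # []     new: auto axis 'j' skipped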
@@ -1547,10 +1547,11 @@ def _promote_scalar_residuals_jaxpr(jaxpr, which):
   return jaxpr
 
 
-def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
+def _unmentioned2(mesh: Mesh, names: AxisNames,
+                  auto: frozenset[AxisName]) -> list[AxisName]:
   # We use a filtered-down version of unmentioned to avoid defensive-psum over
   # more chips than required in the transpose-no-check-rep case.
-  name_set = {n for ns in names.values() for n in ns}
+  name_set = {n for ns in names.values() for n in ns} | auto
   return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]
 
 
@@ -1559,7 +1560,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
   mb_div = lambda x, y: x / y if y != 1 else x
   out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
              else x if rewrite or dtypes.dtype(x) == dtypes.float0
-             else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns))))
+             else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, ns, auto))))
              for ns, x in zip(out_names, out_cts)]
   args = [x if type(x) is not ad.UndefinedPrimal else
           ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
@@ -1577,7 +1578,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
   )
   out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
          else x if rewrite
-         else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns)))
+         else jax.lax.psum(x, tuple(_unmentioned2(mesh, ns, auto)))
          for ns, x in zip(in_names, out)]
   return out
 
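
For context, the `prod(map(mesh.shape.get, ...))` divisor and the `jax.lax.psum` axes are both driven by the same unmentioned-axes list, so excluding auto axes shrinks both. A rough illustration of the divisor with hypothetical values:

    from math import prod

    mesh_shape = {'i': 2, 'j': 2}   # hypothetical 2x2 mesh
    # Old behavior: with only 'i' mentioned, 'j' is unmentioned.
    print(prod(map(mesh_shape.get, ['j'])))  # 2
    # New behavior: 'j' is auto, so nothing is unmentioned.
    print(prod(map(mesh_shape.get, [])))     # 1 -> mb_div leaves x unchanged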
@@ -2046,6 +2046,29 @@ class ShardMapTest(jtu.JaxTestCase):
     v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
     self.assertAllClose(v*v, f(v), check_dtypes=False)
 
+  def test_grad_nested_partial_auto(self):
+    mesh = jtu.create_mesh((2, 2), ('i', 'j'))
+
+    def g(x):
+      return x * x
+
+    def h(x):
+      return shard_map(g, mesh,
+                       in_specs=P(None, 'j'),
+                       out_specs=P(None, 'j'))(x)
+
+    @jax.jit
+    def f(x):
+      return shard_map(h, mesh,
+                       in_specs=P('i', None),
+                       out_specs=P('i', None),
+                       check_rep=False,
+                       auto=frozenset({'j'}))(x).sum()
+
+    v = jnp.arange(32.).reshape(4, 8)
+    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+    self.assertAllClose(v*2, jax.grad(f)(v), check_dtypes=False)
+
   def test_axis_size_1_partial_auto(self):
     mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k'))
 
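
As a sanity check on the expected value in the new test (independent of the mesh machinery): the function under test computes sum(x * x), whose gradient is 2 * x, hence the `v*2` expectation. A quick single-device check, assuming nothing beyond plain jax:

    import jax
    import jax.numpy as jnp

    v = jnp.arange(32.).reshape(4, 8)
    # grad of sum(x*x) is 2*x; the shard_map test expects the same value.
    assert jnp.allclose(jax.grad(lambda x: (x * x).sum())(v), 2 * v)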