Rename psum2 to psum_invariant and move it into lax_parallel. We shouldn't expose it in the public API; users should call psum instead, which dispatches to psum_invariant when check_rep=True.
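For context, a minimal sketch of the intended call path (not part of this change; assumes a multi-device runtime with varying_axes_in_types and rep-checking enabled, and uses jax.make_mesh for brevity):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((jax.device_count(),), ('x',))

@jax.jit
def f(x):
  # Users call the public psum; psum_invariant is bound internally.
  return shard_map(lambda y: jax.lax.psum(y, 'x'), mesh,
                   in_specs=P('x'), out_specs=P())(x)

# With rep-checking on, the traced jaxpr contains a 'psum_invariant'
# equation rather than the old 'psum2'.
print(jax.make_jaxpr(f)(jnp.arange(float(jax.device_count()))))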

PiperOrigin-RevId: 745352875
Authored by Yash Katariya on 2025-04-08 17:28:04 -07:00; committed by jax authors.
parent 84016bc368
commit f95f6a8bdb
6 changed files with 42 additions and 42 deletions

View File

@@ -1257,6 +1257,7 @@ reducing_transposes: dict[core.Primitive, Callable] = {}

########################### pvary ##################################

def _pvary_transpose_rule(cts, *_, axes, axis_index_groups):
-  from jax.experimental.shard_map import psum2_p
-  return psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
+  from jax._src.lax import parallel as lax_parallel
+  return lax_parallel.psum_invariant_p.bind(
+      *cts, axes=axes, axis_index_groups=axis_index_groups)
deflinear2(core.pvary_p, _pvary_transpose_rule)
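The rule above is one half of a transpose pair: pvary transposes to psum_invariant, and psum_invariant (registered at the bottom of lax parallel in this change) transposes back to pvary. A hedged sketch of the visible effect, mirroring the fan-in/fan-out tests later in this change (mesh setup is an assumption):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((jax.device_count(),), ('x',))
f = shard_map(lambda x: jax.lax.psum(x, 'x'), mesh,
              in_specs=P('x'), out_specs=P())

x = jnp.arange(float(jax.device_count()))
y, vjp = jax.vjp(f, x)
# Differentiating a fan-in sum yields a fan-out copy: the cotangent is
# broadcast (pvary'd) back to every shard.
print(vjp(jnp.ones_like(y)))  # (array of ones, one element per device,)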

View File

@@ -145,25 +145,24 @@ def psum(x, axis_name, *, axis_index_groups=None):
    out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves)
  else:
    if config.varying_axes_in_types.value and config._check_rep.value:
-      out_flat = bind_psum2_p(leaves, axes=tuple(axis_name),
-                              axis_index_groups=axis_index_groups)
+      out_flat = bind_psum_invariant(
+          leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups)
    else:
      out_flat = psum_p.bind(
          *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)

-def bind_psum2_p(leaves, *, axes, axis_index_groups):
+def bind_psum_invariant(leaves, *, axes, axis_index_groups):
  if axis_index_groups is not None:
    raise NotImplementedError
-  from jax.experimental.shard_map import psum2_p
  axes_ = frozenset(axes)
  args_ = []
  for x in leaves:
    in_vma = core.get_aval(x).vma
    args_.append(pvary(x, tuple(pbroadcast_names))
                 if (pbroadcast_names := axes_ - in_vma) else x)
-  return psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups)
+  return psum_invariant_p.bind(*args_, axes=axes,
+                               axis_index_groups=axis_index_groups)

def pmean(x, axis_name, *, axis_index_groups=None):
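bind_psum_invariant normalizes operands before binding: any leaf whose aval does not already vary over every axis in `axes` is first pvary'd, so psum_invariant_p only sees inputs varying over all reduced axes. The vma ("varying manual axes") set it inspects via core.get_aval(x).vma is also visible on tracers; a small sketch (mesh setup and flag state are assumptions, and vma may be empty when the flags are off):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((jax.device_count(),), ('x',))

def body(x):
  print(x.aval.vma)   # {'x'}: x varies across the mesh axis
  y = jax.lax.psum(x, 'x')
  print(y.aval.vma)   # set(): the sum is invariant over 'x'
  return y

shard_map(body, mesh, in_specs=P('x'), out_specs=P())(
    jnp.arange(float(jax.device_count())))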
@@ -827,7 +826,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
  ]
  return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}

-def _psum2_abstract_eval(name, *args, axes, axis_index_groups):
+def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups):
  if not config.varying_axes_in_types.value:
    return psum_p.abstract_eval(
        *args, axes=axes, axis_index_groups=axis_index_groups)
@@ -864,7 +863,7 @@ def _psum2_abstract_eval(name, *args, axes, axis_index_groups):
  ]
  return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}

-# TODO(yashkatariya): Replace this with _psum2_abstract_eval
+# TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval
def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
  if not config.varying_axes_in_types.value:
    return _allreduce_effectful_abstract_eval(
@@ -872,8 +871,8 @@ def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
  if not config._check_rep.value:
    return _allreduce_effectful_abstract_eval(
        *args, axes=axes, axis_index_groups=axis_index_groups)
-  return _psum2_abstract_eval(name, *args, axes=axes,
-                              axis_index_groups=axis_index_groups)
+  return _psum_invariant_abstract_eval(
+      name, *args, axes=axes, axis_index_groups=axis_index_groups)

def _check_axis_names(axes):
  named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
@@ -1998,3 +1997,19 @@ mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
# TODO: Transpose? That requires adding pscatter...
batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher
batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes')
+
+psum_invariant_p = core.Primitive('psum_invariant')
+psum_invariant_p.multiple_results = True
+psum_invariant_p.def_impl(psum_p.impl)
+psum_invariant_p.def_effectful_abstract_eval(
+    partial(_psum_invariant_abstract_eval, psum_invariant_p.name))
+mlir.register_lowering(psum_invariant_p, mlir._lowerings[psum_p])
+batching.fancy_primitive_batchers[psum_invariant_p] = partial(
+    _batched_reduction_collective, psum_invariant_p,
+    lambda v, axis_size: axis_size * v)
+batching.skippable_batchers[psum_invariant_p] = partial(_names_in_param, 'axes')
+
+def _psum_invariant_transpose_rule(cts, *args, axes, axis_index_groups):
+  del args
+  return core.pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
+ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule)
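The block above follows JAX's standard primitive recipe: a concrete impl, an (effectful) abstract eval, an MLIR lowering, batching rules, and a transpose registered with ad.deflinear2. As a self-contained illustration of that recipe only, here is a hypothetical toy linear primitive (scale2_p, not part of this change) wired up the same way; it leans on jax._src internals, which are not a stable API:

import jax
from jax._src import core
from jax._src.interpreters import ad

scale2_p = core.Primitive('scale2')       # hypothetical toy primitive
scale2_p.def_impl(lambda x: 2.0 * x)      # concrete evaluation
scale2_p.def_abstract_eval(lambda x: x)   # shape/dtype are preserved

# deflinear2 installs both the linear JVP and the transpose rule; the
# transpose of x -> 2x is ct -> 2ct, returned per operand.
ad.deflinear2(scale2_p, lambda ct, x: [2.0 * ct])

print(jax.grad(lambda x: scale2_p.bind(x))(3.0))  # 2.0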

View File

@@ -174,6 +174,8 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
        continue
      if p.name == "pvary":
        continue
+      if p.name == "psum_invariant":
+        continue
      if p.name == "sharding_constraint":
        continue
      if p.name == "dll_constraint":

View File

@@ -467,7 +467,7 @@ roofline.register_roofline(lax_parallel.pmin_p)(_scalar_collective_roofline)
roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline)

-@roofline.register_roofline(shard_map.psum2_p)
+@roofline.register_roofline(lax_parallel.psum_invariant_p)
def _psum2_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,

View File

@@ -1141,26 +1141,6 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):
  return xs
eager_rules[dispatch.device_put_p] = _device_put_eager_rule

-# New primitives for efficient transposition
-# psum2_p is like psum_p except has a different transpose, so mostly copied:
-psum2_p = core.Primitive('psum2')
-psum2_p.multiple_results = True
-psum2_p.def_impl(lax_parallel.psum_p.impl)
-psum2_p.def_effectful_abstract_eval(
-    partial(lax_parallel._psum2_abstract_eval, psum2_p.name))
-mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p])
-batching.fancy_primitive_batchers[psum2_p] = \
-    partial(lax_parallel._batched_reduction_collective, psum2_p,
-            lambda v, axis_size: axis_size * v)
-batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes')
-
-def _psum2_transpose_rule(cts, *args, axes, axis_index_groups):
-  del args
-  return pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
-ad.deflinear2(psum2_p, _psum2_transpose_rule)

# Rewrite rules and static replication checking for efficient transposition
_rewrite_rules: dict[core.Primitive, Callable] = {}
@@ -1297,11 +1277,12 @@ def _psum_rewrite(mesh, in_rep, *args, axes, axis_index_groups):
  out_rep = [r | axes_ for r in in_rep]  # TODO determinism (and elsewhere)
  args_ = [pvary(x, tuple(n for n in mesh.axis_names if n in axes_ & src))
           for x, src in zip(args, in_rep)]
-  out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups)
+  out_val = lax_parallel.psum_invariant_p.bind(
+      *args_, axes=axes, axis_index_groups=axis_index_groups)
  return out_val, out_rep

-@register_check(psum2_p)
+@register_check(lax_parallel.psum_invariant_p)
def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
  assert type(axes) is tuple
  if any(set(axes) & r for r in in_rep if r is not None):
@@ -1312,7 +1293,7 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
        "workaround pass the check_rep=False argument to shard_map")
  in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
  return [r | set(axes) for r in in_rep]
-register_norewrite(psum2_p)
+register_norewrite(lax_parallel.psum_invariant_p)

@register_check(pvary_p)
@@ -2342,8 +2323,8 @@ def _rewrite_bwd(bwd: lu.WrappedFun,

def _match_replication(src, dst, x):
  if dst - src:
-    x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src),
-                      axis_index_groups=None)
+    x, = lax_parallel.psum_invariant_p.bind(
+        x, axes=tuple(n for n in dst if n not in src), axis_index_groups=None)
  if src - dst:
    x = pvary(x, tuple(n for n in src if n not in dst))
  return x
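_match_replication reconciles a value's tracked replication set src with a required set dst: replication it lacks (dst - src) is gained by summing with psum_invariant, and surplus replication (src - dst) is shed with pvary. The set arithmetic, spelled out on hypothetical axis names:

src, dst = {'x'}, {'x', 'y'}
print(dst - src)  # {'y'}: gain replication over 'y' via psum_invariant_p
print(src - dst)  # set(): nothing to pvary away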

View File

@@ -1614,7 +1614,7 @@ class ShardMapTest(jtu.JaxTestCase):
    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e2, = e.params['jaxpr'].eqns
-    self.assertEqual(str(e2.primitive), 'psum2')
+    self.assertEqual(str(e2.primitive), 'psum_invariant')
    self.assertEqual(e2.params['axes'], ('x',))

  def test_fanin_psum_transposes_to_fanout(self):
@@ -1639,7 +1639,7 @@ class ShardMapTest(jtu.JaxTestCase):
    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e1, e2 = e.params['jaxpr'].eqns
-    self.assertEqual(str(e1.primitive), 'psum2')
+    self.assertEqual(str(e1.primitive), 'psum_invariant')
    self.assertEqual(str(e2.primitive), 'pvary')

  def test_transpose_float0(self):
@@ -1701,7 +1701,8 @@ class ShardMapTest(jtu.JaxTestCase):
      self.assertEqual(y.aval.vma, {'x'})
      return y

-    f(jnp.arange(8))
+    f(jnp.arange(8.))
+    jax.grad(lambda x: f(x).sum())(jnp.arange(8.))

  def test_rewrite_binops(self):
    mesh = jtu.create_mesh((4,), ('x',))
@@ -1729,7 +1730,7 @@ class ShardMapTest(jtu.JaxTestCase):
    e, = jaxpr.jaxpr.eqns
    e, = e.params['jaxpr'].eqns
    e1, e2 = e.params['jaxpr'].eqns
-    self.assertEqual(e1.primitive.name, 'psum2')
+    self.assertEqual(e1.primitive.name, 'psum_invariant')
    self.assertEqual(e2.primitive.name, 'pvary')

  def test_check_rep_false_grads(self):