Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
Rename psum2 to psum_invariant and move it into lax_parallel. We shouldn't expose this primitive in the public API; users should call psum instead, which dispatches to psum_invariant when check_rep=True.

PiperOrigin-RevId: 745352875
Commit: f95f6a8bdb
Parent: 84016bc368
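For context, a minimal usage sketch (not part of this commit) of the dispatch described above. It assumes four local devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU); whether the internal primitive shows up in the traced jaxpr is also gated on the varying_axes_in_types config, as the psum hunk below shows.

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

# Users only ever call the public psum; with check_rep=True (the default),
# it dispatches to the internal psum_invariant primitive under the hood.
mesh = jax.make_mesh((4,), ('x',))
f = shard_map(lambda x: jax.lax.psum(x, 'x'),
              mesh=mesh, in_specs=P('x'), out_specs=P(), check_rep=True)

print(f(jnp.arange(8.)))                  # [12. 16.], per-position sum over 'x'
print(jax.make_jaxpr(f)(jnp.arange(8.)))  # inner jaxpr may show psum_invariant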
@@ -1257,6 +1257,7 @@ reducing_transposes: dict[core.Primitive, Callable] = {}
 ########################### pvary ##################################
 
 def _pvary_transpose_rule(cts, *_, axes, axis_index_groups):
-  from jax.experimental.shard_map import psum2_p
-  return psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
+  from jax._src.lax import parallel as lax_parallel
+  return lax_parallel.psum_invariant_p.bind(
+      *cts, axes=axes, axis_index_groups=axis_index_groups)
 deflinear2(core.pvary_p, _pvary_transpose_rule)
@@ -145,25 +145,24 @@ def psum(x, axis_name, *, axis_index_groups=None):
     out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves)
   else:
     if config.varying_axes_in_types.value and config._check_rep.value:
-      out_flat = bind_psum2_p(leaves, axes=tuple(axis_name),
-                              axis_index_groups=axis_index_groups)
+      out_flat = bind_psum_invariant(
+          leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups)
     else:
       out_flat = psum_p.bind(
           *leaves, axes=tuple(axis_name), axis_index_groups=axis_index_groups)
   return tree_util.tree_unflatten(treedef, out_flat)
 
-def bind_psum2_p(leaves, *, axes, axis_index_groups):
+def bind_psum_invariant(leaves, *, axes, axis_index_groups):
   if axis_index_groups is not None:
     raise NotImplementedError
-
-  from jax.experimental.shard_map import psum2_p
   axes_ = frozenset(axes)
   args_ = []
   for x in leaves:
     in_vma = core.get_aval(x).vma
     args_.append(pvary(x, tuple(pbroadcast_names))
                  if (pbroadcast_names := axes_ - in_vma) else x)
-  return psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups)
+  return psum_invariant_p.bind(*args_, axes=axes,
+                               axis_index_groups=axis_index_groups)
 
 
 def pmean(x, axis_name, *, axis_index_groups=None):
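As an aside, the normalization in bind_psum_invariant above is plain set arithmetic on axis names: any leaf whose varying axes (vma) do not yet cover every reduced axis gets pvary'd over exactly the missing names before psum_invariant is bound. A toy illustration, with hypothetical axis names and values:

# Toy illustration (hypothetical values) of the pbroadcast_names computation
# in bind_psum_invariant above; not part of this commit.
axes_ = frozenset({'x', 'y'})  # axes being psum-reduced
in_vma = frozenset({'x'})      # axes this leaf already varies over
pbroadcast_names = axes_ - in_vma
print(pbroadcast_names)        # frozenset({'y'}): pvary over 'y' before psum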
@@ -827,7 +826,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
   ]
   return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
 
-def _psum2_abstract_eval(name, *args, axes, axis_index_groups):
+def _psum_invariant_abstract_eval(name, *args, axes, axis_index_groups):
   if not config.varying_axes_in_types.value:
     return psum_p.abstract_eval(
         *args, axes=axes, axis_index_groups=axis_index_groups)
@@ -864,7 +863,7 @@ def _psum2_abstract_eval(name, *args, axes, axis_index_groups):
   ]
   return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
 
-# TODO(yashkatariya): Replace this with _psum2_abstract_eval
+# TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval
 def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
   if not config.varying_axes_in_types.value:
     return _allreduce_effectful_abstract_eval(
@@ -872,8 +871,8 @@ def _pmin_pmax_abstract_eval(name, *args, axes, axis_index_groups):
   if not config._check_rep.value:
     return _allreduce_effectful_abstract_eval(
         *args, axes=axes, axis_index_groups=axis_index_groups)
-  return _psum2_abstract_eval(name, *args, axes=axes,
-                              axis_index_groups=axis_index_groups)
+  return _psum_invariant_abstract_eval(
+      name, *args, axes=axes, axis_index_groups=axis_index_groups)
 
 def _check_axis_names(axes):
   named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
@@ -1998,3 +1997,19 @@ mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
 # TODO: Transpose? That requires adding pscatter...
 batching.fancy_primitive_batchers[pgather_p] = _pgather_collective_batcher
 batching.skippable_batchers[pgather_p] = partial(_names_in_param, 'axes')
+
+psum_invariant_p = core.Primitive('psum_invariant')
+psum_invariant_p.multiple_results = True
+psum_invariant_p.def_impl(psum_p.impl)
+psum_invariant_p.def_effectful_abstract_eval(
+    partial(_psum_invariant_abstract_eval, psum_invariant_p.name))
+mlir.register_lowering(psum_invariant_p, mlir._lowerings[psum_p])
+batching.fancy_primitive_batchers[psum_invariant_p] = partial(
+    _batched_reduction_collective, psum_invariant_p,
+    lambda v, axis_size: axis_size * v)
+batching.skippable_batchers[psum_invariant_p] = partial(_names_in_param, 'axes')
+
+def _psum_invariant_transpose_rule(cts, *args, axes, axis_index_groups):
+  del args
+  return core.pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
+ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule)
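Taken together with the first hunk, the registration above makes psum_invariant and pvary transposes of each other, so differentiating through a fan-in psum yields a cheap fan-out rather than another all-reduce. A sketch of how one might observe this, under the same hypothetical 4-device setup as above; the exact equations depend on the config gates in this diff, and the shard_map tests near the end check them precisely.

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((4,), ('x',))
f = shard_map(lambda x: jax.lax.psum(x, 'x'),
              mesh=mesh, in_specs=P('x'), out_specs=P())

# Trace the VJP: the transpose rule registered above turns the forward psum
# into a pvary in the backward pass.
_, vjp_fn = jax.vjp(f, jnp.arange(8.))
print(jax.make_jaxpr(vjp_fn)(jnp.arange(2.)))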
@@ -174,6 +174,8 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
       continue
     if p.name == "pvary":
       continue
+    if p.name == "psum_invariant":
+      continue
     if p.name == "sharding_constraint":
       continue
     if p.name == "dll_constraint":
@@ -467,7 +467,7 @@ roofline.register_roofline(lax_parallel.pmin_p)(_scalar_collective_roofline)
 roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline)
 
 
-@roofline.register_roofline(shard_map.psum2_p)
+@roofline.register_roofline(lax_parallel.psum_invariant_p)
 def _psum2_roofline(
     ctx: roofline.RooflineRuleContext,
     *args,
@@ -1141,26 +1141,6 @@ def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):
   return xs
 eager_rules[dispatch.device_put_p] = _device_put_eager_rule
 
-# New primitives for efficient transposition
-
-# psum2_p is like psum_p except has a different transpose, so mostly copied:
-psum2_p = core.Primitive('psum2')
-psum2_p.multiple_results = True
-psum2_p.def_impl(lax_parallel.psum_p.impl)
-psum2_p.def_effectful_abstract_eval(
-    partial(lax_parallel._psum2_abstract_eval, psum2_p.name))
-
-mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p])
-batching.fancy_primitive_batchers[psum2_p] = \
-    partial(lax_parallel._batched_reduction_collective, psum2_p,
-            lambda v, axis_size: axis_size * v)
-batching.skippable_batchers[psum2_p] = partial(lax_parallel._names_in_param, 'axes')
-
-def _psum2_transpose_rule(cts, *args, axes, axis_index_groups):
-  del args
-  return pvary_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
-ad.deflinear2(psum2_p, _psum2_transpose_rule)
-
 # Rewrite rules and static replication checking for efficient transposition
 
 _rewrite_rules: dict[core.Primitive, Callable] = {}
@@ -1297,11 +1277,12 @@ def _psum_rewrite(mesh, in_rep, *args, axes, axis_index_groups):
   out_rep = [r | axes_ for r in in_rep]  # TODO determinism (and elsewhere)
   args_ = [pvary(x, tuple(n for n in mesh.axis_names if n in axes_ & src))
            for x, src in zip(args, in_rep)]
-  out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups)
+  out_val = lax_parallel.psum_invariant_p.bind(
+      *args_, axes=axes, axis_index_groups=axis_index_groups)
   return out_val, out_rep
 
 
-@register_check(psum2_p)
+@register_check(lax_parallel.psum_invariant_p)
 def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
   assert type(axes) is tuple
   if any(set(axes) & r for r in in_rep if r is not None):
@@ -1312,7 +1293,7 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups):
           "workaround pass the check_rep=False argument to shard_map")
   in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep)
   return [r | set(axes) for r in in_rep]
-register_norewrite(psum2_p)
+register_norewrite(lax_parallel.psum_invariant_p)
 
 
 @register_check(pvary_p)
@@ -2342,8 +2323,8 @@ def _rewrite_bwd(bwd: lu.WrappedFun,
 
 def _match_replication(src, dst, x):
   if dst - src:
-    x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src),
-                      axis_index_groups=None)
+    x, = lax_parallel.psum_invariant_p.bind(
+        x, axes=tuple(n for n in dst if n not in src), axis_index_groups=None)
   if src - dst:
     x = pvary(x, tuple(n for n in src if n not in dst))
   return x
@@ -1614,7 +1614,7 @@ class ShardMapTest(jtu.JaxTestCase):
     jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.))
     e, = jaxpr.jaxpr.eqns
     e2, = e.params['jaxpr'].eqns
-    self.assertEqual(str(e2.primitive), 'psum2')
+    self.assertEqual(str(e2.primitive), 'psum_invariant')
     self.assertEqual(e2.params['axes'], ('x',))
 
   def test_fanin_psum_transposes_to_fanout(self):
@@ -1639,7 +1639,7 @@ class ShardMapTest(jtu.JaxTestCase):
     jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.))
     e, = jaxpr.jaxpr.eqns
     e1, e2 = e.params['jaxpr'].eqns
-    self.assertEqual(str(e1.primitive), 'psum2')
+    self.assertEqual(str(e1.primitive), 'psum_invariant')
     self.assertEqual(str(e2.primitive), 'pvary')
 
   def test_transpose_float0(self):
@@ -1701,7 +1701,8 @@ class ShardMapTest(jtu.JaxTestCase):
       self.assertEqual(y.aval.vma, {'x'})
       return y
 
-    f(jnp.arange(8))
+    f(jnp.arange(8.))
+    jax.grad(lambda x: f(x).sum())(jnp.arange(8.))
 
   def test_rewrite_binops(self):
     mesh = jtu.create_mesh((4,), ('x',))
@@ -1729,7 +1730,7 @@ class ShardMapTest(jtu.JaxTestCase):
     e, = jaxpr.jaxpr.eqns
     e, = e.params['jaxpr'].eqns
     e1, e2 = e.params['jaxpr'].eqns
-    self.assertEqual(e1.primitive.name, 'psum2')
+    self.assertEqual(e1.primitive.name, 'psum_invariant')
     self.assertEqual(e2.primitive.name, 'pvary')
 
   def test_check_rep_false_grads(self):