[sharding_in_types] Add sharding rule for reduce sum which is just drop the specs for the axis we are reducing over

PiperOrigin-RevId: 685069065
This commit is contained in:
Yash Katariya 2024-10-11 21:30:30 -07:00 committed by jax authors
parent 89fcd9f1f1
commit 5b8775dc2f
2 changed files with 36 additions and 2 deletions

View File

@ -4781,6 +4781,12 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
def _reduce_sum_shape_rule(operand, *, axes):
return _reduce_op_shape_rule(operand, axes=axes)
def _reduce_sum_sharding_rule(operand, *, axes):
axes = frozenset(axes)
new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
if i not in axes))
return NamedSharding(operand.sharding.mesh, new_spec)
def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
assert ad.is_undefined_primal(operand)
input_shape = operand.aval.shape
@ -4806,7 +4812,7 @@ def _replace_masked_values(x, val, padded_axes):
reduce_sum_p = standard_primitive(
_reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum')
'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,

View File

@ -4607,7 +4607,7 @@ def spec_regex(s):
class ShardingInTypesTest(jtu.JaxTestCase):
def test_basic_mul(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@ -4758,6 +4758,34 @@ class ShardingInTypesTest(jtu.JaxTestCase):
aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
@parameterized.named_parameters(
('all', None,P('x', 'y'), P()),
('first', 0, P('x', 'y'), P('y')),
('second', 1, P('x', 'y'), P('x')),
('first2', 0, P(('x', 'y'), None), P(None)),
('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
)
def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = NamedSharding(mesh, in_spec)
arr = jax.device_put(np_inp, s)
@jax.jit
def f(x):
self.assertEqual(x.sharding.spec, s.spec)
y = jnp.sum(x, axis=axis)
self.assertEqual(y.sharding.spec, out_spec)
return y
out = f(arr)
self.assertArraysEqual(out, np.sum(np_inp, axis=axis))
self.assertEqual(out.aval.sharding.spec, out_spec)
compiled_text = f.lower(arr).compile().as_text()
if reduce and compiled_text is not None:
self.assertIn('all-reduce', compiled_text)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):