Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[sharding_in_types] Add a sharding rule for reduce_sum, which simply drops the specs for the axes being reduced over.

PiperOrigin-RevId: 685069065
This commit is contained in:
parent 89fcd9f1f1
commit 5b8775dc2f
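In other words, the output spec is the input spec with the entries for the reduced dimensions removed. A hedged end-to-end illustration of the intended behavior (not part of the commit; it assumes at least 4 devices, e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=4, and the concrete output sharding representation may vary by backend and mode):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
x = jax.device_put(np.arange(16.).reshape(8, 2),
                   NamedSharding(mesh, P('x', 'y')))
y = jax.jit(lambda a: jnp.sum(a, axis=0))(x)
# Summing over axis 0 drops its 'x' entry; expect a spec like P('y').
print(y.sharding)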
@@ -4781,6 +4781,12 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
 def _reduce_sum_shape_rule(operand, *, axes):
   return _reduce_op_shape_rule(operand, axes=axes)
 
+def _reduce_sum_sharding_rule(operand, *, axes):
+  axes = frozenset(axes)
+  new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
+                      if i not in axes))
+  return NamedSharding(operand.sharding.mesh, new_spec)
+
 def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   assert ad.is_undefined_primal(operand)
   input_shape = operand.aval.shape
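The rule's spec arithmetic can be spot-checked without any devices, using only the public PartitionSpec API. A minimal sketch (the helper name _drop is mine, not from the commit); the five assertions mirror the parameterized test cases added further down:

from jax.sharding import PartitionSpec as P

def _drop(spec, axes):
  # Mirror _reduce_sum_sharding_rule: keep only the entries whose
  # dimension index is not being reduced over.
  axes = frozenset(axes)
  return P(*(s for i, s in enumerate(spec) if i not in axes))

assert _drop(P('x', 'y'), (0, 1)) == P()                  # 'all'
assert _drop(P('x', 'y'), (0,)) == P('y')                 # 'first'
assert _drop(P('x', 'y'), (1,)) == P('x')                 # 'second'
assert _drop(P(('x', 'y'), None), (0,)) == P(None)        # 'first2'
assert _drop(P(('x', 'y'), None), (1,)) == P(('x', 'y'))  # 'second2'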
@@ -4806,7 +4812,7 @@ def _replace_masked_values(x, val, padded_axes):
 
 reduce_sum_p = standard_primitive(
     _reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
-    'reduce_sum')
+    'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
 ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
 batching.defreducer(reduce_sum_p, _get_sum_identity)
 pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
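For context, standard_primitive is JAX's internal helper that bundles the shape and dtype rules when defining a primitive; this hunk simply threads the new sharding rule through the same call. A sketch of the pattern with hypothetical names (_my_reduce_* are placeholders; the keyword signature is assumed from the diff itself):

my_reduce_p = standard_primitive(
    _my_reduce_shape_rule,                              # shape propagation
    partial(_reduce_number_dtype_rule, 'my_reduce'),    # dtype propagation
    'my_reduce',
    sharding_rule=_my_reduce_sharding_rule)             # new: sharding propagation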
@@ -4607,7 +4607,7 @@ def spec_regex(s):
 class ShardingInTypesTest(jtu.JaxTestCase):
 
   def test_basic_mul(self):
-    mesh = jtu.create_mesh((4, 2), ('x', 'y'))
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
     np_inp = np.arange(16).reshape(8, 2)
     s = NamedSharding(mesh, P('x', 'y'))
     arr = jax.device_put(np_inp, s)
@@ -4758,6 +4758,34 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     aval = aval.update(sharding=NamedSharding(mesh, P(('x', 'y'), None)))
     self.assertEqual(aval.str_short(), 'float32[8@xy,2]')
 
+  @parameterized.named_parameters(
+      ('all', None, P('x', 'y'), P()),
+      ('first', 0, P('x', 'y'), P('y')),
+      ('second', 1, P('x', 'y'), P('x')),
+      ('first2', 0, P(('x', 'y'), None), P(None)),
+      ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
+  )
+  def test_reduce_sum(self, axis, in_spec, out_spec, reduce=True):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, in_spec)
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      self.assertEqual(x.sharding.spec, s.spec)
+      y = jnp.sum(x, axis=axis)
+      self.assertEqual(y.sharding.spec, out_spec)
+      return y
+
+    out = f(arr)
+    self.assertArraysEqual(out, np.sum(np_inp, axis=axis))
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    compiled_text = f.lower(arr).compile().as_text()
+    if reduce and compiled_text is not None:
+      self.assertIn('all-reduce', compiled_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
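A note on the parameterized cases: an all-reduce is only required when the dimension being summed is actually sharded. In 'second2' the reduced axis 1 has spec None, so no cross-device communication is expected, which is why that case passes reduce=False. Reusing the mesh and input x from the sketch near the top, the test's HLO check can be reproduced standalone (hedged; availability of the compiled text varies by backend):

f = jax.jit(lambda a: jnp.sum(a, axis=0))
hlo = f.lower(x).compile().as_text()
print('all-reduce' in hlo)  # True expected: the summed axis is sharded over 'x'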