Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)
[sharding_in_types] Add reduce max, integer_pow and standard_unop sharding rules

PiperOrigin-RevId: 687073144

parent e92e1191b3
commit 5df4878ad0
@@ -1995,11 +1995,14 @@ def unop_dtype_rule(result_dtype, accepted_dtypes, name, aval, **kwargs):

 def unop(result_dtype, accepted_dtypes, name):
   dtype_rule = partial(unop_dtype_rule, result_dtype, accepted_dtypes, name)
-  prim = standard_primitive(_attrgetter('shape'), dtype_rule, name)
+  prim = standard_primitive(_attrgetter('shape'), dtype_rule, name,
+                            sharding_rule=_attrgetter('sharding'))
   batching.defvectorized(prim)
   pe.def_trivial_padding(prim)
   return prim

 standard_unop = partial(unop, _identity)

 _attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
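Reviewer note: using _attrgetter('sharding') as the sharding rule encodes
that an elementwise unary op preserves its operand's sharding. A minimal
sketch of that contract (FakeAval is a hypothetical stand-in for a JAX
abstract value, not part of this change):

    from dataclasses import dataclass

    @dataclass
    class FakeAval:
        shape: tuple
        sharding: str  # placeholder; really a NamedSharding

    _attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
    sharding_rule = _attrgetter('sharding')

    x = FakeAval(shape=(8, 2), sharding="P('x', 'y')")
    assert sharding_rule(x) == "P('x', 'y')"  # sin, exp, neg, ... keep the spec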
@@ -2584,7 +2587,8 @@ def _integer_pow_jvp(g, x, *, y):
   return _zeros(g) if y == 0 else mul(g, mul(_const(x, y), integer_pow(x, y - 1)))

 integer_pow_p = standard_primitive(
-    _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow')
+    _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow',
+    sharding_rule=_attrgetter('sharding'))
 batching.defvectorized(integer_pow_p)
 ad.defjvp(integer_pow_p, _integer_pow_jvp)
 pe.def_trivial_padding(integer_pow_p)
@@ -2611,9 +2615,9 @@ def _integer_pow_lowering(ctx, x, *, y):
   # These cases are subsumed by the general case, but it's faster to emit these
   # common cases directly.
   if y == 2:
-    return (hlo.multiply(x, x),)
+    out = hlo.multiply(x, x)
   elif y == 3:
-    return (hlo.multiply(hlo.multiply(x, x), x),)
+    out = hlo.multiply(hlo.multiply(x, x), x)
   else:
     lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
     # TODO(b/217551391): emitting an out-of-line call leads to a large
@@ -2621,7 +2625,13 @@ def _integer_pow_lowering(ctx, x, *, y):
     # clones the callee. Consider unconditionally caching when the MLIR->HLO
     # lowering doesn't expand the program.
     lowering = mlir.cache_lowering(lowering)
-    return lowering(ctx, x, y=y)
+    out = lowering(ctx, x, y=y)
+  if config.sharding_in_types.value:
+    aval_out, = ctx.avals_out
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    out = out[0] if isinstance(out, list) else out
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return out if isinstance(out, list) else [out]

 mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
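Reviewer note: the switch from early returns to a single `out` lets one tail
serve both result shapes: MLIR lowering rules must return a sequence, but
`out` is a bare value on the y == 2 / y == 3 fast paths and a list on the
generic path. A toy restatement of that tail (the `wrap` callback is a
hypothetical stand-in for mlir.wrap_with_sharding_op):

    def finalize(out, wrap=None):
        if wrap is not None:                    # sharding_in_types enabled
            val = out[0] if isinstance(out, list) else out
            return [wrap(val)]                  # annotate, re-wrap as a list
        return out if isinstance(out, list) else [out]

    assert finalize('v') == ['v']
    assert finalize(['v']) == ['v']
    assert finalize('v', wrap=lambda v: ('sharded', v)) == [('sharded', 'v')]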
@@ -4846,15 +4856,6 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
         "of number.".format(name, dtype_to_string(operand.dtype)))
   return dtypes.canonicalize_dtype(operand.dtype)

-def _reduce_sum_shape_rule(operand, *, axes):
-  return _reduce_op_shape_rule(operand, axes=axes)
-
-def _reduce_sum_sharding_rule(operand, *, axes):
-  axes = frozenset(axes)
-  new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
-                      if i not in axes))
-  return NamedSharding(operand.sharding.mesh, new_spec)
-
 def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   assert ad.is_undefined_primal(operand)
   input_shape = operand.aval.shape
@@ -4877,16 +4878,6 @@ def _replace_masked_values(x, val, padded_axes):
   masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
   return select(_reduce(operator.and_, masks), x, full_like(x, val))

-
-reduce_sum_p = standard_primitive(
-    _reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
-    'reduce_sum', sharding_rule=_reduce_sum_sharding_rule)
-ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
-batching.defreducer(reduce_sum_p, _get_sum_identity)
-pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
-                                         _get_sum_identity)
-
 def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
   del input_shape  # Unused.
   if len(axes) != len(set(axes)):
@@ -4896,6 +4887,20 @@ def _reduce_op_shape_rule(operand, *, axes, input_shape=None):
   axes = frozenset(axes)
   return tuple(d for i, d in enumerate(operand.shape) if i not in axes)

+def _reduce_op_sharding_rule(operand, *, axes):
+  axes = frozenset(axes)
+  new_spec = P(*tuple(s for i, s in enumerate(operand.sharding.spec)
+                      if i not in axes))
+  return NamedSharding(operand.sharding.mesh, new_spec)
+
+reduce_sum_p = standard_primitive(
+    _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
+    'reduce_sum', sharding_rule=_reduce_op_sharding_rule)
+ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
+batching.defreducer(reduce_sum_p, _get_sum_identity)
+pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
+                                         _get_sum_identity)
+
 def _reduce_prod_jvp_rule(primals, tangents, *, axes):
   reducer = lambda x, y: [mul(x, y)]
   primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
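Reviewer note: _reduce_op_sharding_rule simply deletes the PartitionSpec
entries of the reduced dimensions (this also generalizes the deleted
_reduce_sum_sharding_rule so reduce_max can share it). A standalone
restatement of that spec arithmetic, assuming only that PartitionSpec is
iterable and comparable, which the rule above relies on as well:

    from jax.sharding import PartitionSpec as P

    def reduced_spec(spec, axes):
        axes = frozenset(axes)
        # keep the partition entry of every dimension that survives
        return P(*tuple(s for i, s in enumerate(spec) if i not in axes))

    assert reduced_spec(P('x', 'y'), axes=(0,)) == P('y')  # 'x' dropped
    assert reduced_spec(P('x', 'y'), axes=(1,)) == P('x')  # 'y' dropped
    assert reduced_spec(P('x', 'y'), axes=(0, 1)) == P()   # fully replicated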
@@ -4922,8 +4927,9 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
   return div(_reduce_sum(mul(g, location_indicators), axes), counts)

-reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
-                                  'reduce_max')
+reduce_max_p = standard_primitive(
+    _reduce_op_shape_rule, _input_dtype, 'reduce_max',
+    sharding_rule=_reduce_op_sharding_rule)
 ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
 batching.defreducer(reduce_max_p, _get_max_identity)
 pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
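Reviewer note: taken together, these lax.py changes mean a jitted reduction
over a sharded array now carries a sharding in its result type. A hedged
end-to-end sketch (assumes at least four devices and that the experimental
sharding_in_types config is enabled, as the test class below arranges):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    arr = jax.device_put(np.arange(16.).reshape(8, 2),
                         NamedSharding(mesh, P('x', 'y')))

    out = jax.jit(lambda x: jnp.max(x, axis=0))(arr)  # reduce the 'x' dim
    print(out.sharding.spec)  # expected: PartitionSpec('y'), via all-reduce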
@@ -4784,6 +4784,37 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     if reduce and compiled_text is not None:
       self.assertIn('all-reduce', compiled_text)

+  @parameterized.named_parameters(
+      ('all', None, P('x', 'y'), P()),
+      ('first', 0, P('x', 'y'), P('y')),
+      ('second', 1, P('x', 'y'), P('x')),
+      ('first2', 0, P(('x', 'y'), None), P(None)),
+      ('second2', 1, P(('x', 'y'), None), P(('x', 'y')), False),
+  )
+  def test_reduce_max(self, axis, in_spec, out_spec, reduce=True):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, in_spec)
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      self.assertEqual(x.sharding.spec, s.spec)
+      y = jnp.max(x, axis=axis)
+      self.assertEqual(y.sharding.spec, out_spec)
+      return y
+
+    out = f(arr)
+    self.assertArraysEqual(out, np.max(np_inp, axis=axis))
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    lowered = f.lower(arr)
+    self.assertIn('@Sharding', lowered.as_text())
+
+    compiled_text = lowered.compile().as_text()
+    if reduce and compiled_text is not None:
+      self.assertIn('all-reduce', compiled_text)
+
   @parameterized.named_parameters(
       ('0', 0, P(None, 'x', 'y')),
       ('1', 1, P('x', None, 'y')),
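Reviewer note on the `second2` case: with input spec P(('x', 'y'), None),
only dimension 0 is sharded, so reducing over dimension 1 removes no mesh
axes. The output stays sharded as P(('x', 'y')), no cross-device reduction
is needed, and the trailing False disables the all-reduce assertion via the
`reduce` parameter.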
@@ -4811,6 +4842,48 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)

+  @parameterized.named_parameters(
+      ('2', 2),
+      ('3', 3),
+      ('4', 4),
+  )
+  def test_integer_pow(self, pow):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = x ** pow
+      self.assertEqual(y.sharding.spec, s.spec)
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, np_inp ** pow)
+
+    lowered_text = f.lower(arr).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
+  def test_sin_unop(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = lax.sin(x)
+      self.assertEqual(y.sharding.spec, s.spec)
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, s)
+
+    lowered_text = f.lower(arr).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
+
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):