[sharding_in_types] Add broadcast_in_dim rule.

PiperOrigin-RevId: 687054181
Authored by Yash Katariya on 2024-10-17 14:54:26 -07:00; committed by jax authors
parent 93389ab5f4
commit e92e1191b3
3 changed files with 47 additions and 7 deletions


@@ -1817,7 +1817,7 @@ class ShapedArray(UnshapedArray):
     dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
               self.dtype.name)
     dt_str = dt_str.replace('void', 'float0')
-    if hasattr(self, 'sharding'):
+    if hasattr(self, 'sharding') and self.sharding is not None:
       shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
       return f'{dt_str}[{shapestr}]'
     else:
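For context, the guard above makes the short aval string fall back to the plain shape whenever no sharding is attached to the aval. A minimal, self-contained sketch of that behavior (the helper name and the exact "8@x" output format here are illustrative, not JAX internals):

def short_aval_str(dtype_name, shape, spec=None):
  # With a spec, annotate each sharded dimension; otherwise print the plain shape.
  if spec is not None:
    dims = ','.join(str(d) if p is None else f'{d}@{p}'
                    for d, p in zip(shape, spec))
  else:
    dims = ','.join(map(str, shape))
  return f'{dtype_name}[{dims}]'

# short_aval_str('float32', (8, 2), ('x', 'y'))  -> 'float32[8@x,2@y]'
# short_aval_str('float32', (8, 2))              -> 'float32[8,2]'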


@@ -3928,9 +3928,15 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
     msg = ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
            "got broadcast_dimensions {}")
     raise TypeError(msg.format(broadcast_dimensions))
   return shape

+def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions):
+  bds = set(broadcast_dimensions)
+  orig_spec = iter(operand.sharding.spec)
+  new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))]
+  assert next(orig_spec, None) is None
+  return NamedSharding(operand.sharding.mesh, P(*new_spec))
+
 def _broadcast_in_dim_typecheck_rule(
     _, operand, *dyn_shape, shape, broadcast_dimensions):
   if not dyn_shape:
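The new sharding rule maps the operand's spec onto the output: each output dimension listed in broadcast_dimensions inherits the operand's spec entry in order, and every newly created dimension is left unsharded. A pure-Python sketch of that mapping (no mesh or NamedSharding needed; the function name is illustrative):

def broadcast_spec(operand_spec, out_rank, broadcast_dimensions):
  bds = set(broadcast_dimensions)
  it = iter(operand_spec)
  # Consume the operand spec in order for mapped dims; new dims get None.
  new_spec = [next(it) if i in bds else None for i in range(out_rank)]
  assert next(it, None) is None  # every operand dim must have been consumed
  return tuple(new_spec)

# Operand spec ('x', 'y') broadcast from shape (8, 2) to (8, 1, 2) with
# broadcast_dimensions=(0, 2) gives ('x', None, 'y').
assert broadcast_spec(('x', 'y'), 3, (0, 2)) == ('x', None, 'y')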
@@ -4079,10 +4085,12 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) ->
   aval_out, = ctx.avals_out
   if dyn_shape:
     aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
-  return [mlir.broadcast_in_dim(ctx, x, aval_out,
-                                broadcast_dimensions=broadcast_dimensions)]
+  out = mlir.broadcast_in_dim(ctx, x, aval_out,
+                              broadcast_dimensions=broadcast_dimensions)
+  if config.sharding_in_types.value:
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return [out]

 def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
   if (not dyn_shape and
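In the lowering above, the NamedSharding produced by the rule is converted to an HloSharding proto and attached to the broadcast result via a sharding custom call (the '@Sharding' op the test below looks for). A hedged sketch of that proto construction, assuming at least four local devices so a 2x2 mesh can be built:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)   # assumes >= 4 devices
mesh = Mesh(devices, ('x', 'y'))
out_sharding = NamedSharding(mesh, P('x', None, 'y'))    # what the rule returns
proto = out_sharding._to_xla_hlo_sharding(3).to_proto()  # 3 = output ndim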
@@ -4090,7 +4098,12 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
               type(core.get_aval(d).dtype) is core.bint for d in shape)):
     shape = _broadcast_in_dim_shape_rule(  # error checking
         x, shape=shape, broadcast_dimensions=broadcast_dimensions)
-    return core.ShapedArray(shape, x.dtype, x.weak_type)
+    if config.sharding_in_types.value:
+      sharding = _broadcast_in_dim_sharding_rule(
+          x, shape=shape, broadcast_dimensions=broadcast_dimensions)
+    else:
+      sharding = None
+    return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=sharding)
   # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
   # (even if x is a ShapedArray)
   # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
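The abstract-eval change is what the test below observes through jnp.expand_dims, which is implemented in terms of lax.broadcast_in_dim: the inserted size-1 axis is the only output dimension missing from broadcast_dimensions, so it is the one left unsharded. A small sketch of that correspondence for a non-negative axis (the helper name is illustrative):

import jax

def expand_dims_via_broadcast(x, axis):
  # Insert a size-1 dimension at `axis`; all original dims keep a slot in
  # broadcast_dimensions, so the sharding rule leaves only the new axis unsharded.
  out_shape = x.shape[:axis] + (1,) + x.shape[axis:]
  bdims = tuple(i for i in range(len(out_shape)) if i != axis)
  return jax.lax.broadcast_in_dim(x, out_shape, broadcast_dimensions=bdims)

# For x of shape (8, 2) sharded as P('x', 'y'), axis=1 gives shape (8, 1, 2),
# broadcast_dimensions=(0, 2), and (with sharding-in-types enabled) an output
# spec of P('x', None, 'y').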


@@ -4784,6 +4784,33 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     if reduce and compiled_text is not None:
       self.assertIn('all-reduce', compiled_text)

+  @parameterized.named_parameters(
+      ('0', 0, P(None, 'x', 'y')),
+      ('1', 1, P('x', None, 'y')),
+      ('2', 2, P('x', 'y', None)),
+      ('-1', -1, P('x', 'y', None)),
+  )
+  def test_broadcast_in_dim(self, axis, out_spec):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    out = jnp.expand_dims(arr, axis=axis)
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    @jax.jit
+    def f(x):
+      y = jnp.expand_dims(x, axis=axis)
+      self.assertEqual(y.sharding.spec, out_spec)
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    lowered_text = f.lower(arr).as_text()
+    self.assertIn('@Sharding', lowered_text)
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):