mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
This commit is contained in:
parent
93389ab5f4
commit
e92e1191b3
@ -1817,7 +1817,7 @@ class ShapedArray(UnshapedArray):
|
||||
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
|
||||
self.dtype.name)
|
||||
dt_str = dt_str.replace('void', 'float0')
|
||||
if hasattr(self, 'sharding'):
|
||||
if hasattr(self, 'sharding') and self.sharding is not None:
|
||||
shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
else:
|
||||
|
@ -3928,9 +3928,15 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
|
||||
msg = ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
|
||||
"got broadcast_dimensions {}")
|
||||
raise TypeError(msg.format(broadcast_dimensions))
|
||||
|
||||
return shape
|
||||
|
||||
def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions):
|
||||
bds = set(broadcast_dimensions)
|
||||
orig_spec = iter(operand.sharding.spec)
|
||||
new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))]
|
||||
assert next(orig_spec, None) is None
|
||||
return NamedSharding(operand.sharding.mesh, P(*new_spec))
|
||||
|
||||
def _broadcast_in_dim_typecheck_rule(
|
||||
_, operand, *dyn_shape, shape, broadcast_dimensions):
|
||||
if not dyn_shape:
|
||||
@ -4079,10 +4085,12 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) ->
|
||||
aval_out, = ctx.avals_out
|
||||
if dyn_shape:
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
|
||||
|
||||
|
||||
return [mlir.broadcast_in_dim(ctx, x, aval_out,
|
||||
broadcast_dimensions=broadcast_dimensions)]
|
||||
out = mlir.broadcast_in_dim(ctx, x, aval_out,
|
||||
broadcast_dimensions=broadcast_dimensions)
|
||||
if config.sharding_in_types.value:
|
||||
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
|
||||
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
|
||||
return [out]
|
||||
|
||||
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
|
||||
if (not dyn_shape and
|
||||
@ -4090,7 +4098,12 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
|
||||
type(core.get_aval(d).dtype) is core.bint for d in shape)):
|
||||
shape = _broadcast_in_dim_shape_rule( # error checking
|
||||
x, shape=shape, broadcast_dimensions=broadcast_dimensions)
|
||||
return core.ShapedArray(shape, x.dtype, x.weak_type)
|
||||
if config.sharding_in_types.value:
|
||||
sharding = _broadcast_in_dim_sharding_rule(
|
||||
x, shape=shape, broadcast_dimensions=broadcast_dimensions)
|
||||
else:
|
||||
sharding = None
|
||||
return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=sharding)
|
||||
# If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
|
||||
# (even if x is a ShapedArray)
|
||||
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
|
||||
|
@ -4784,6 +4784,33 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
if reduce and compiled_text is not None:
|
||||
self.assertIn('all-reduce', compiled_text)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('0', 0, P(None, 'x', 'y')),
|
||||
('1', 1, P('x', None, 'y')),
|
||||
('2', 2, P('x', 'y', None)),
|
||||
('-1', -1, P('x', 'y', None)),
|
||||
)
|
||||
def test_broadcast_in_dim(self, axis, out_spec):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
out = jnp.expand_dims(arr, axis=axis)
|
||||
self.assertEqual(out.aval.sharding.spec, out_spec)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = jnp.expand_dims(x, axis=axis)
|
||||
self.assertEqual(y.sharding.spec, out_spec)
|
||||
return y
|
||||
|
||||
out = f(arr)
|
||||
self.assertEqual(out.aval.sharding.spec, out_spec)
|
||||
|
||||
lowered_text = f.lower(arr).as_text()
|
||||
self.assertIn('@Sharding', lowered_text)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PJitErrorTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user