Make sharding_in_types work with Shardy

PiperOrigin-RevId: 713479962
commit b2b38679e2 (parent fb832afc00)
Author: Yash Katariya, 2025-01-08 18:05:04 -08:00 (committed by jax authors)
6 changed files with 116 additions and 49 deletions

View File

@@ -49,7 +49,8 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
-from jax._src.sharding_impls import AUTO, NamedSharding
+from jax._src.sharding_impls import (AUTO, NamedSharding,
+                                     modify_sdy_sharding_wrt_axis_types)
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -1689,13 +1690,17 @@ def lower_jaxpr_to_fun(
           for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

     if ir_result_shardings is not None:
-      flat_outputs = [
-          wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
-                                unspecified_dims=us[2])
-          if us[0] and not us[1] else o
-          for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
-                                      output_avals, unconstrained_shardings)  # type: ignore
-      ]
+      temp_flat_outputs = []
+      for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
+                                  output_avals, unconstrained_shardings):  # type: ignore
+        if us[0] and not us[1]:
+          if config.use_shardy_partitioner.value and config.sharding_in_types.value:
+            s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
+          temp_flat_outputs.append(wrap_with_sharding_op(
+              entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
+        else:
+          temp_flat_outputs.append(o)
+      flat_outputs = temp_flat_outputs

     # Insert a custom call if output is on host because XLA needs that to do the
     # transfer.
@@ -2594,14 +2599,20 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
     return op
   # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
   # `return op` early and avoid bloating HLO size.
-  proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
-           if sharding_proto is None else sharding_proto)
-  unspecified_dims = None
-  if aval.sharding.mesh._any_axis_collective:
-    unspecified_dims = set(range(aval.ndim))
-  elif aval.sharding.mesh._any_axis_auto:
-    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
-  return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
+  if config.use_shardy_partitioner.value:
+    proto = (aval.sharding._to_sdy_sharding(aval.ndim)
+             if sharding_proto is None else sharding_proto)
+    proto = modify_sdy_sharding_wrt_axis_types(proto, aval.sharding.mesh)
+    return wrap_with_sharding_op(ctx, op, aval, proto)
+  else:
+    proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+             if sharding_proto is None else sharding_proto)
+    unspecified_dims = None
+    if aval.sharding.mesh._any_axis_auto:
+      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes
+      # as unspecified?
+      unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
+    return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


 def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
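For orientation, a hedged sketch (not part of this change) of how the two branches above surface in lowered text; the helper below only inspects the same strings that the new test helper later in this commit checks for, and the function name is illustrative.

def sharding_op_kind(jitted_fn, *args) -> str:
  # Lower a jax.jit-wrapped function and report which wrapping op appeared.
  text = jitted_fn.lower(*args).as_text()
  if 'sdy.sharding_constraint' in text:  # Shardy branch; open dims print as {?}
    return 'shardy'
  if '@Sharding' in text:                # GSPMD branch; uses unspecified_dims
    return 'gspmd'
  return 'none'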

View File

@@ -2163,12 +2163,9 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
   out = []
   for s, a in zip(shardings, avals):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
-      if config.use_shardy_partitioner.value:
-        spec = a.sharding.spec
-      else:
-        spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
-                                for sp in a.sharding.spec])
-                if a.sharding.mesh._any_axis_auto else a.sharding.spec)
+      spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
+                              for sp in a.sharding.spec])
+              if a.sharding.mesh._any_axis_auto else a.sharding.spec)
       out.append(NamedSharding(
           _abstract_to_concrete_mesh(a.sharding.mesh), spec))
     else:

View File

@@ -2467,8 +2467,7 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
     if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
       out.append(op)
     else:
-      proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
-      out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto))
+      out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval))
   return out

View File

@@ -2710,7 +2710,9 @@ ad.deflinear2(sharding_cast_p, _sharding_cast_transpose_rule)
 def _sharding_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
   aval, = ctx.avals_in
   aval_out, = ctx.avals_out
-  proto = dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+  proto = (dst_sharding._to_sdy_sharding(aval.ndim)
+           if config.use_shardy_partitioner.value else
+           dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto())
   return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]

 mlir.register_lowering(sharding_cast_p, _sharding_cast_hlo_lowering)

View File

@@ -142,6 +142,7 @@ class SdyArraySharding:
   mesh_shape: tuple[tuple[str, int], ...] | None
   dimension_shardings: Sequence[SdyDimSharding]
   logical_device_ids: tuple[int, ...] | None = None
+  replicated_axes: tuple[str, ...] = ()

   # NOTE: An MLIR context is required as a context manager.
   def build(self) -> sdy.TensorShardingAttr:
@@ -155,14 +156,17 @@ class SdyArraySharding:
                                        ldi)
     return sdy.TensorShardingAttr.get(
         mesh_attr,
-        [dim_sharding.build() for dim_sharding in self.dimension_shardings])
+        [dim_sharding.build() for dim_sharding in self.dimension_shardings],
+        replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes])

   def __repr__(self):
     dim_sharding_repr = ', '.join(
         d._custom_repr() for d in self.dimension_shardings)
     device_id_repr = (f', device_ids={self.logical_device_ids}'
                       if self.logical_device_ids is not None else '')
-    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"
+    rar = (f', replicated_axes={self.replicated_axes}'
+           if self.replicated_axes else '')
+    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"


 @util.cache(max_size=4096, trace_context_in_key=False)
@@ -425,6 +429,23 @@ class NamedSharding(jsharding.Sharding):
     return SdyArraySharding(self.mesh.shape_tuple, dim_shardings,
                             self._logical_device_ids)


+# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
+# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
+def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
+  if mesh._any_axis_auto:
+    dim_shardings, used_axes = [], []  # type: ignore
+    for d in sdy_sharding.dimension_shardings:
+      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open?
+      dim_shardings.append(SdyDimSharding(axes=[], is_closed=False)
+                           if not d.axes and d.is_closed else d)
+      used_axes.extend(d.axes)
+    remaining_axes = set(mesh.axis_names) - set(used_axes)
+    replicated_axes = tuple(r for r in remaining_axes
+                            if mesh._name_to_type[r] == mesh_lib.AxisTypes.User)
+    return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
+                            sdy_sharding.logical_device_ids, replicated_axes)
+  return sdy_sharding
+

 @util.cache(max_size=128, trace_context_in_key=False)
 def get_replicated_hlo_sharding():
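A minimal standalone sketch of the rewrite above, for readers unfamiliar with Shardy's representation. The dataclasses and the axis_types dict are simplified stand-ins for the real SdyDimSharding/SdyArraySharding and mesh objects (none of these toy names exist in JAX); only the rewrite logic mirrors modify_sdy_sharding_wrt_axis_types.

from dataclasses import dataclass

@dataclass
class ToyDimSharding:
  axes: list        # mesh axis names this dimension is sharded over
  is_closed: bool   # closed: the partitioner may not further shard this dim

@dataclass
class ToyArraySharding:
  dim_shardings: list
  replicated_axes: tuple = ()

def toy_modify_wrt_axis_types(sharding, axis_types):
  # axis_types maps mesh axis name -> 'User' or 'Auto'.
  if 'Auto' not in axis_types.values():
    return sharding
  dims, used = [], []
  for d in sharding.dim_shardings:
    # Open any empty, closed dimension so the partitioner is free to shard it.
    dims.append(ToyDimSharding([], False) if not d.axes and d.is_closed else d)
    used.extend(d.axes)
  # User-declared axes the sharding does not mention are pinned as replicated.
  replicated = tuple(a for a, t in axis_types.items()
                     if a not in used and t == 'User')
  return ToyArraySharding(dims, replicated)

# P('x', None, None) on a mesh where x, y are User axes and z is Auto:
inp = ToyArraySharding([ToyDimSharding(['x'], True),
                        ToyDimSharding([], True),
                        ToyDimSharding([], True)])
out = toy_modify_wrt_axis_types(inp, {'x': 'User', 'y': 'User', 'z': 'Auto'})
# out.dim_shardings  -> [{"x"}, {?}, {?}]  (second and third dims opened)
# out.replicated_axes -> ('y',)            printed by Shardy as replicated={"y"}

This matches the '[{"x"}, {?}, {?}], replicated={"y"}' string the new test below asserts on.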

View File

@@ -4729,9 +4729,14 @@ def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")


-@jtu.with_config(jax_use_shardy_partitioner=False)
 class ShardingInTypesTest(jtu.JaxTestCase):

+  def check_wsc_in_lowered(self, text):
+    if config.use_shardy_partitioner.value:
+      self.assertIn('sdy.sharding_constraint', text)
+    else:
+      self.assertIn('@Sharding', text)
+
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_basic_mul(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
@@ -4753,7 +4758,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     lowered_text = f.lower(arr).as_text()
     if config.use_shardy_partitioner.value:
-      self.assertIn('sdy.sharding_constraint', lowered_text)
+      self.assertEqual(lowered_text.count('sdy.sharding_constraint'), 3)
     else:
       self.assertEqual(lowered_text.count('@Sharding'), 3)
@@ -4834,7 +4839,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr1, arr2)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())
     compiled_text = lowered.compile().as_text()
     if collective_name is not None and compiled_text is not None:
@@ -4971,7 +4976,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())
     compiled_text = lowered.compile().as_text()
     if reduce and compiled_text is not None:
@@ -5002,7 +5007,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())
     compiled_text = lowered.compile().as_text()
     if reduce and compiled_text is not None:
@@ -5044,7 +5049,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @parameterized.named_parameters(
       ('2', 2),
@@ -5068,7 +5073,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp ** pow)
     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @jtu.with_user_mesh((1,), 'x')
   def test_broadcasting_nary_error(self, mesh):
@@ -5102,7 +5107,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, s)

     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_jnp_array(self, mesh):
@@ -5137,7 +5142,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'z', 'x')))

     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_broadcasted_iota_with_sharding(self, mesh):
@@ -5182,7 +5187,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

     lowered_text = f.lower(arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     @jax.jit
     def g(x, y):
@@ -5228,7 +5233,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None)))

     lowered_text = h.lower(arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     @jax.jit
     def h2(x, y):
@@ -5268,7 +5273,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)

     lowered_text = f.lower(arr, new_s).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     def g(x):
       out = f(x, new_s)
@@ -5295,7 +5300,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, arr1)

     lowered_text = f.lower(arr1 == arr2, arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x')))
     with self.assertRaisesRegex(
@@ -5383,7 +5388,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())

     def g(x):
       out = f(x)
@@ -5414,7 +5419,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
     self.assertArraysEqual(out, np.squeeze(np_inp, axis=2))

     def g(x):
@@ -5441,7 +5446,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr, ((2, 2, 0),), P('x'))
     self.assertArraysEqual(out, np.pad(np_inp, 2))
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
-    self.assertIn('@Sharding', f.lower(arr, ((2, 2, 0),), P('x')).as_text())
+    self.check_wsc_in_lowered(f.lower(arr, ((2, 2, 0),), P('x')).as_text())

     out = f(arr, ((0, 0, 0),), P('x'))
     self.assertArraysEqual(out, np_inp)
@@ -5489,7 +5494,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr1, arr2)
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1))
-    self.assertIn('@Sharding', f.lower(arr1, arr2).as_text())
+    self.check_wsc_in_lowered(f.lower(arr1, arr2).as_text())

     out = f(arr1, arr2, method='lax')
     self.assertEqual(out.sharding, s)
@@ -5568,8 +5573,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out1.sharding, NamedSharding(mesh, P('y')))
     self.assertArraysEqual(out2, np.argmin(np_inp, axis=1))
     self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())

   @jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
   def test_only_auto(self, mesh):
@@ -5618,7 +5622,10 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr, arr2)
     self.assertEqual(out.sharding, NamedSharding(mesh2, P('x',)))
     lowered_text = f.lower(arr, arr2).as_text()
-    self.assertTrue(lowered_text.count("unspecified_dims") == 5)
+    if config.use_shardy_partitioner.value:
+      self.assertTrue(lowered_text.count("{?}") == 5)
+    else:
+      self.assertTrue(lowered_text.count("unspecified_dims") == 5)

     mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
                             axis_types={mesh_lib.AxisTypes.User: 'y',
@@ -5629,7 +5636,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr, arr2)
     self.assertEqual(out.sharding, NamedSharding(mesh3, P('x',)))
     lowered_text = f.lower(arr, arr2).as_text()
-    self.assertTrue(lowered_text.count("unspecified_dims") == 4)
+    print(lowered_text)
+    if config.use_shardy_partitioner.value:
+      self.assertTrue(lowered_text.count("{?}") == 5)
+      self.assertIn('replicated={"y"}', lowered_text)
+    else:
+      self.assertTrue(lowered_text.count("unspecified_dims") == 4)

     with self.assertRaisesRegex(
         ValueError,
@@ -5784,7 +5796,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       return ys

     f(arr)
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())

     with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"):
       f(arr, sizes=(1, 1), axis=1)
@@ -5864,6 +5876,31 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         ValueError, "PartitionSpec cannot contain axis names.*Auto"):
       g(arr1, arr2)

+  @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'),
+                      axis_types={AxisTypes.User: ('x', 'y'),
+                                  AxisTypes.Auto: 'z'})
+  def test_out_sharding_mix_axis_types(self, mesh):
+    np_inp = np.arange(16).reshape(4, 2, 2)
+    s = NamedSharding(mesh, P('x', None, None))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = x * 2
+      self.assertEqual(y.sharding.spec, P('x', None, None))
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
+    self.assertArraysEqual(out, np_inp * 2)
+
+    lowered_text = f.lower(arr).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertTrue(lowered_text.count(
+          '[{"x"}, {?}, {?}], replicated={"y"}') == 3)
+    else:
+      self.assertTrue(lowered_text.count("unspecified_dims=[1,2]") == 3)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):