Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
commit b2b38679e2
parent fb832afc00
@@ -49,7 +49,8 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
-from jax._src.sharding_impls import AUTO, NamedSharding
+from jax._src.sharding_impls import (AUTO, NamedSharding,
+                                     modify_sdy_sharding_wrt_axis_types)
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -1689,13 +1690,17 @@ def lower_jaxpr_to_fun(
          for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]
 
     if ir_result_shardings is not None:
-      flat_outputs = [
-          wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
-                                unspecified_dims=us[2])
-          if us[0] and not us[1] else o
-          for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
-                                      output_avals, unconstrained_shardings)  # type: ignore
-      ]
+      temp_flat_outputs = []
+      for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
+                                  output_avals, unconstrained_shardings):  # type: ignore
+        if us[0] and not us[1]:
+          if config.use_shardy_partitioner.value and config.sharding_in_types.value:
+            s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
+          temp_flat_outputs.append(wrap_with_sharding_op(
+              entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
+        else:
+          temp_flat_outputs.append(o)
+      flat_outputs = temp_flat_outputs
 
     # Insert a custom call if output is on host because XLA needs that to do the
     # transfer.
@@ -2594,14 +2599,20 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
     return op
   # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
   # `return op` early and avoid bloating HLO size.
-  proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
-           if sharding_proto is None else sharding_proto)
-  unspecified_dims = None
-  if aval.sharding.mesh._any_axis_collective:
-    unspecified_dims = set(range(aval.ndim))
-  elif aval.sharding.mesh._any_axis_auto:
-    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
-  return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
+  if config.use_shardy_partitioner.value:
+    proto = (aval.sharding._to_sdy_sharding(aval.ndim)
+             if sharding_proto is None else sharding_proto)
+    proto = modify_sdy_sharding_wrt_axis_types(proto, aval.sharding.mesh)
+    return wrap_with_sharding_op(ctx, op, aval, proto)
+  else:
+    proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+             if sharding_proto is None else sharding_proto)
+    unspecified_dims = None
+    if aval.sharding.mesh._any_axis_auto:
+      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes
+      # as unspecified?
+      unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
+    return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
 
 
 def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
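
The GSPMD branch above leaves dimensions whose PartitionSpec entry is None "unspecified" whenever the mesh has Auto axes. A minimal standalone sketch of that index computation, with `spec` and `any_axis_auto` as illustrative stand-ins for aval.sharding.spec and aval.sharding.mesh._any_axis_auto:

# Sketch only; not part of the diff.
spec = ('x', None, None)      # stands in for aval.sharding.spec
any_axis_auto = True          # stands in for mesh._any_axis_auto
unspecified_dims = ({i for i, s in enumerate(spec) if s is None}
                    if any_axis_auto else None)
assert unspecified_dims == {1, 2}
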
@@ -2163,12 +2163,9 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment):
   out = []
   for s, a in zip(shardings, avals):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
-      if config.use_shardy_partitioner.value:
-        spec = a.sharding.spec
-      else:
-        spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
-                                for sp in a.sharding.spec])
-                if a.sharding.mesh._any_axis_auto else a.sharding.spec)
+      spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
+                              for sp in a.sharding.spec])
+              if a.sharding.mesh._any_axis_auto else a.sharding.spec)
       out.append(NamedSharding(
           _abstract_to_concrete_mesh(a.sharding.mesh), spec))
     else:
@@ -2467,8 +2467,7 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
     if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
       out.append(op)
     else:
-      proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
-      out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto))
+      out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval))
   return out
 
 
@@ -2710,7 +2710,9 @@ ad.deflinear2(sharding_cast_p, _sharding_cast_transpose_rule)
 def _sharding_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
   aval, = ctx.avals_in
   aval_out, = ctx.avals_out
-  proto = dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+  proto = (dst_sharding._to_sdy_sharding(aval.ndim)
+           if config.use_shardy_partitioner.value else
+           dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto())
   return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
 mlir.register_lowering(sharding_cast_p, _sharding_cast_hlo_lowering)
 
@@ -142,6 +142,7 @@ class SdyArraySharding:
   mesh_shape: tuple[tuple[str, int], ...] | None
   dimension_shardings: Sequence[SdyDimSharding]
   logical_device_ids: tuple[int, ...] | None = None
+  replicated_axes: tuple[str, ...] = ()
 
   # NOTE: An MLIR context is required as a context manager.
   def build(self) -> sdy.TensorShardingAttr:
@@ -155,14 +156,17 @@ class SdyArraySharding:
                                    ldi)
     return sdy.TensorShardingAttr.get(
         mesh_attr,
-        [dim_sharding.build() for dim_sharding in self.dimension_shardings])
+        [dim_sharding.build() for dim_sharding in self.dimension_shardings],
+        replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes])
 
   def __repr__(self):
     dim_sharding_repr = ', '.join(
         d._custom_repr() for d in self.dimension_shardings)
     device_id_repr = (f', device_ids={self.logical_device_ids}'
                       if self.logical_device_ids is not None else '')
-    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"
+    rar = (f', replicated_axes={self.replicated_axes}'
+           if self.replicated_axes else '')
+    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"
 
 
 @util.cache(max_size=4096, trace_context_in_key=False)
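
A small sketch of what the extended __repr__ now prints; the field values below are illustrative, not taken from the diff:

# Illustrative stand-ins for a SdyArraySharding instance's fields.
dim_sharding_repr = '{"x"}, {?}'
logical_device_ids = None
replicated_axes = ('y',)
device_id_repr = (f', device_ids={logical_device_ids}'
                  if logical_device_ids is not None else '')
rar = (f', replicated_axes={replicated_axes}' if replicated_axes else '')
print(f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})")
# -> SdyArraySharding([{"x"}, {?}], replicated_axes=('y',))
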
@@ -425,6 +429,23 @@ class NamedSharding(jsharding.Sharding):
     return SdyArraySharding(self.mesh.shape_tuple, dim_shardings,
                             self._logical_device_ids)
 
+# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
+# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
+def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
+  if mesh._any_axis_auto:
+    dim_shardings, used_axes = [], []  # type: ignore
+    for d in sdy_sharding.dimension_shardings:
+      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open?
+      dim_shardings.append(SdyDimSharding(axes=[], is_closed=False)
+                           if not d.axes and d.is_closed else d)
+      used_axes.extend(d.axes)
+    remaining_axes = set(mesh.axis_names) - set(used_axes)
+    replicated_axes = tuple(r for r in remaining_axes
+                            if mesh._name_to_type[r] == mesh_lib.AxisTypes.User)
+    return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
+                            sdy_sharding.logical_device_ids, replicated_axes)
+  return sdy_sharding
 
+
 @util.cache(max_size=128, trace_context_in_key=False)
 def get_replicated_hlo_sharding():
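
For orientation, here is a standalone sketch (not part of the diff) of what modify_sdy_sharding_wrt_axis_types does, using simplified stand-in classes for SdyDimSharding, SdyArraySharding, and the mesh's axis-type table: on a mesh that has Auto axes, empty closed dimension shardings become open, and User axes that no dimension uses are recorded as replicated.

# Standalone sketch; the real code operates on jax._src.sharding_impls types.
from dataclasses import dataclass

@dataclass
class DimSharding:            # stand-in for SdyDimSharding
  axes: list
  is_closed: bool

@dataclass
class ArraySharding:          # stand-in for SdyArraySharding
  dim_shardings: list
  replicated_axes: tuple = ()

def modify_wrt_axis_types(sharding, axis_types):
  # axis_types: {axis_name: 'User' | 'Auto'}, standing in for mesh axis types.
  if not any(t == 'Auto' for t in axis_types.values()):
    return sharding
  dims, used = [], []
  for d in sharding.dim_shardings:
    # Empty, closed dimensions ({}) become open ({?}) so the partitioner may
    # still shard them along Auto axes.
    dims.append(DimSharding([], False) if not d.axes and d.is_closed else d)
    used.extend(d.axes)
  # User axes that no dimension mentions are pinned as replicated.
  replicated = tuple(a for a, t in axis_types.items()
                     if t == 'User' and a not in used)
  return ArraySharding(dims, replicated)

# Mesh axes: x and y are User, z is Auto; array sharded like P('x', None, None).
s = ArraySharding([DimSharding(['x'], True),
                   DimSharding([], True),
                   DimSharding([], True)])
out = modify_wrt_axis_types(s, {'x': 'User', 'y': 'User', 'z': 'Auto'})
assert all(not d.is_closed for d in out.dim_shardings[1:])  # now open ({?})
assert out.replicated_axes == ('y',)  # unused User axis becomes replicated

This mirrors the expectation in the new test at the end of the diff, where a P('x', None, None) output on a mesh with User axes ('x', 'y') and Auto axis 'z' lowers with the sharding text '[{"x"}, {?}, {?}], replicated={"y"}'.
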
@@ -4729,9 +4729,14 @@ def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
 
 
-@jtu.with_config(jax_use_shardy_partitioner=False)
 class ShardingInTypesTest(jtu.JaxTestCase):
 
+  def check_wsc_in_lowered(self, text):
+    if config.use_shardy_partitioner.value:
+      self.assertIn('sdy.sharding_constraint', text)
+    else:
+      self.assertIn('@Sharding', text)
+
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_basic_mul(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
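
The helper added above is what the remaining test hunks switch to: under Shardy a constraint appears in the lowered module as an sdy.sharding_constraint op, under GSPMD as an @Sharding custom call. A free-function sketch of the same assertion, with `use_shardy` standing in for config.use_shardy_partitioner.value:

# Sketch only; the diff defines this as a method on ShardingInTypesTest.
def check_wsc_in_lowered(text: str, use_shardy: bool) -> None:
  marker = 'sdy.sharding_constraint' if use_shardy else '@Sharding'
  assert marker in text, f'{marker!r} not found in lowered module'
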
@@ -4753,7 +4758,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     lowered_text = f.lower(arr).as_text()
     if config.use_shardy_partitioner.value:
-      self.assertIn('sdy.sharding_constraint', lowered_text)
+      self.assertEqual(lowered_text.count('sdy.sharding_constraint'), 3)
     else:
       self.assertEqual(lowered_text.count('@Sharding'), 3)
 
@@ -4834,7 +4839,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
 
     lowered = f.lower(arr1, arr2)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())
 
     compiled_text = lowered.compile().as_text()
     if collective_name is not None and compiled_text is not None:
@@ -4971,7 +4976,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
 
     lowered = f.lower(arr)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())
 
     compiled_text = lowered.compile().as_text()
     if reduce and compiled_text is not None:
@@ -5002,7 +5007,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
 
     lowered = f.lower(arr)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())
 
     compiled_text = lowered.compile().as_text()
     if reduce and compiled_text is not None:
@@ -5044,7 +5049,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))
 
     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
   @parameterized.named_parameters(
       ('2', 2),
@@ -5068,7 +5073,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp ** pow)
 
     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
   @jtu.with_user_mesh((1,), 'x')
   def test_broadcasting_nary_error(self, mesh):
@@ -5102,7 +5107,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, s)
 
     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_jnp_array(self, mesh):
@@ -5137,7 +5142,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'z', 'x')))
 
     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_broadcasted_iota_with_sharding(self, mesh):
@@ -5182,7 +5187,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
 
     lowered_text = f.lower(arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
     @jax.jit
     def g(x, y):
@@ -5228,7 +5233,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None)))
 
     lowered_text = h.lower(arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
     @jax.jit
     def h2(x, y):
@@ -5268,7 +5273,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)
 
     lowered_text = f.lower(arr, new_s).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
     def g(x):
       out = f(x, new_s)
@@ -5295,7 +5300,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, arr1)
 
     lowered_text = f.lower(arr1 == arr2, arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)
 
     arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x')))
     with self.assertRaisesRegex(
@@ -5383,7 +5388,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     out = f(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
 
     def g(x):
       out = f(x)
@@ -5414,7 +5419,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
 
     out = f(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
     self.assertArraysEqual(out, np.squeeze(np_inp, axis=2))
 
     def g(x):
@@ -5441,7 +5446,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr, ((2, 2, 0),), P('x'))
     self.assertArraysEqual(out, np.pad(np_inp, 2))
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
-    self.assertIn('@Sharding', f.lower(arr, ((2, 2, 0),), P('x')).as_text())
+    self.check_wsc_in_lowered(f.lower(arr, ((2, 2, 0),), P('x')).as_text())
 
     out = f(arr, ((0, 0, 0),), P('x'))
     self.assertArraysEqual(out, np_inp)
@@ -5489,7 +5494,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     out = f(arr1, arr2)
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1))
-    self.assertIn('@Sharding', f.lower(arr1, arr2).as_text())
+    self.check_wsc_in_lowered(f.lower(arr1, arr2).as_text())
 
     out = f(arr1, arr2, method='lax')
     self.assertEqual(out.sharding, s)
@@ -5568,8 +5573,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertEqual(out1.sharding, NamedSharding(mesh, P('y')))
     self.assertArraysEqual(out2, np.argmin(np_inp, axis=1))
     self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
-
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
 
   @jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
   def test_only_auto(self, mesh):
@@ -5618,7 +5622,10 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       out = f(arr, arr2)
       self.assertEqual(out.sharding, NamedSharding(mesh2, P('x',)))
       lowered_text = f.lower(arr, arr2).as_text()
-      self.assertTrue(lowered_text.count("unspecified_dims") == 5)
+      if config.use_shardy_partitioner.value:
+        self.assertTrue(lowered_text.count("{?}") == 5)
+      else:
+        self.assertTrue(lowered_text.count("unspecified_dims") == 5)
 
     mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
                             axis_types={mesh_lib.AxisTypes.User: 'y',
@@ -5629,7 +5636,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       out = f(arr, arr2)
       self.assertEqual(out.sharding, NamedSharding(mesh3, P('x',)))
       lowered_text = f.lower(arr, arr2).as_text()
-      self.assertTrue(lowered_text.count("unspecified_dims") == 4)
+      print(lowered_text)
+      if config.use_shardy_partitioner.value:
+        self.assertTrue(lowered_text.count("{?}") == 5)
+        self.assertIn('replicated={"y"}', lowered_text)
+      else:
+        self.assertTrue(lowered_text.count("unspecified_dims") == 4)
 
     with self.assertRaisesRegex(
         ValueError,
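
For readers unfamiliar with Shardy's textual form: the substrings counted above come from the sharding attribute printed in the lowered module. The surrounding syntax below is an assumption for illustration; only the bracketed fragments are what the tests actually search for.

# Illustrative sharding-attribute text (assumed framing, not from the diff).
example = '#sdy.sharding<@mesh, [{"x"}, {?}], replicated={"y"}>'
# {"x"}            : dimension fixed to mesh axis "x"
# {?}              : open dimension the partitioner may still shard
# replicated={"y"} : axis "y" explicitly replicated
assert example.count('{?}') == 1
assert 'replicated={"y"}' in example
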
@@ -5784,7 +5796,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
       return ys
 
     f(arr)
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
 
     with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"):
       f(arr, sizes=(1, 1), axis=1)
@@ -5864,6 +5876,31 @@ class ShardingInTypesTest(jtu.JaxTestCase):
         ValueError, "PartitionSpec cannot contain axis names.*Auto"):
       g(arr1, arr2)
 
+  @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'),
+                      axis_types={AxisTypes.User: ('x', 'y'),
+                                  AxisTypes.Auto: 'z'})
+  def test_out_sharding_mix_axis_types(self, mesh):
+    np_inp = np.arange(16).reshape(4, 2, 2)
+    s = NamedSharding(mesh, P('x', None, None))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = x * 2
+      self.assertEqual(y.sharding.spec, P('x', None, None))
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
+    self.assertArraysEqual(out, np_inp * 2)
+
+    lowered_text = f.lower(arr).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertTrue(lowered_text.count(
+          '[{"x"}, {?}, {?}], replicated={"y"}') == 3)
+    else:
+      self.assertTrue(lowered_text.count("unspecified_dims=[1,2]") == 3)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):