Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
#sdy add JAX Shardy support for shard_map.
For example the following JAX program:

```py
devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x'))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(jax.jit(fwd).lower(a).as_text())
```

prints:

```cpp
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```

PiperOrigin-RevId: 679165100
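The `sdy` ops above only appear when the Shardy partitioner is enabled; by default `shard_map` still lowers through the GSPMD path. Below is a minimal sketch of turning the flag on before lowering; the flag name is taken from the `config.use_shardy_partitioner` / `jax_use_shardy_partitioner` references in this change, and `fwd` and `a` refer to the objects defined in the example above.

```py
# Sketch: enable the Shardy partitioner so that shard_map lowers to an
# sdy.manual_computation instead of the GSPMD custom calls.
import jax

jax.config.update('jax_use_shardy_partitioner', True)

# With `fwd` and `a` defined as in the example above:
# print(jax.jit(fwd).lower(a).as_text())  # contains "sdy.manual_computation"
```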
This commit is contained in:
parent 7b53c2f39d
commit e62a50cd34
@@ -51,6 +51,8 @@ from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                          windowed_reductions, convolution, fft, linalg,
                          special, control_flow, ann)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3,
                           as_hashable_function, memoize, partition_list,
                           merge_lists, split_list, subs_list2)
@@ -643,9 +645,71 @@ def _rule_missing(prim: core.Primitive, *_, **__):

# Lowering

def _shardy_shard_map_sharding(
    ctx: mlir.LoweringRuleContext, mesh, names, aval_in
) -> ir.Attribute:
  axes = {name: i for i, ns in names.items() for name in ns}
  ns = _make_scoped_manual_sharding(ctx, mesh, axes)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  return ns._to_sdy_sharding(aval_in.ndim).build()


def _shard_map_lowering_shardy(
    ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto):
  in_avals_ = [v.aval for v in jaxpr.invars]
  if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
    # Nested `ManualComputationOp`s cannot refer to axes that are already
    # manual. So figure out what axes are free thus far and get the new axis
    # context.
    free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes
    new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto)
  else:
    new_axis_context = sharding_impls.SPMDAxisContext(
        mesh, frozenset(mesh.axis_names) - auto)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
  args = (*ctx.dim_var_values, *in_nodes)

  manual_axes = sub_ctx.axis_context.manual_axes
  mesh_shape = mesh.shape
  manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes])
  if manual_axes_size == 1:
    # No need for a `ManualComputationOp` if all manual axes are size 1.
    out_nodes, _ = mlir.jaxpr_subcomp(
        sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args,
        dim_var_values=ctx.dim_var_values)
    return out_nodes

  in_shardings = sdy.TensorShardingPerValueAttr.get(map(
      partial(_shardy_shard_map_sharding, ctx, mesh),
      in_names, ctx.avals_in))
  out_shardings = sdy.TensorShardingPerValueAttr.get(map(
      partial(_shardy_shard_map_sharding, ctx, mesh),
      out_names, ctx.avals_out))
  output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
  manual_computation_op = sdy.ManualComputationOp(
      output_types, args, in_shardings, out_shardings,
      sdy.ManualAxesAttr.get(
          ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes])))
  block = ir.Block.create_at_start(
      manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_))
  with ir.InsertionPoint(block), core.extend_axis_env_nd(
      tuple(mesh.shape.items())):
    out_nodes_, _ = mlir.jaxpr_subcomp(
        sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments,
        dim_var_values=ctx.dim_var_values)
    sdy.ReturnOp([ir.Value(x) for x in out_nodes_])

  return manual_computation_op.results


def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
                        check_rep, rewrite, auto):
  del check_rep, rewrite
  if config.use_shardy_partitioner.value:
    return _shard_map_lowering_shardy(
        ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto)
  in_avals_ = [v.aval for v in jaxpr.invars]
  out_avals_ = [x.aval for x in jaxpr.outvars]
  in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
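A note on the `names` argument consumed by `_shardy_shard_map_sharding` above: each entry maps an array dimension to the mesh axes it is sharded over, and the function's first line inverts that mapping before building the Shardy sharding. An illustrative sketch with a hypothetical `names` value (not taken from this change):

```py
# Hypothetical `names` for a rank-2 value whose dimension 0 is sharded over
# mesh axes 'i' and 'j' (roughly what in_specs=P(('i', 'j'), None) yields).
names = {0: ('i', 'j')}

# The inversion used in _shardy_shard_map_sharding: mesh axis -> dimension.
axes = {name: i for i, ns in names.items() for name in ns}
assert axes == {'i': 0, 'j': 0}
```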
@@ -1346,6 +1346,11 @@ jax_multiplatform_test(
jax_multiplatform_test(
    name = "shard_map_test",
    srcs = ["shard_map_test.py"],
    enable_configs = [
        "gpu_2gpu_shardy",
        "tpu_v3_2x2_shardy",
        "tpu_v4_2x2_shardy",
    ],
    shard_count = {
        "cpu": 50,
        "gpu": 10,
@@ -848,6 +848,10 @@ class ShardMapTest(jtu.JaxTestCase):
  @parameterized.parameters([True, False])
  @jtu.run_on_devices('cpu', 'gpu', 'tpu')
  def test_debug_print_jit(self, jit):
    if config.use_shardy_partitioner.value:
      self.skipTest(
          'TODO(b/364547005): debug prints not supported by Shardy yet'
      )
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
@@ -1229,13 +1233,18 @@ class ShardMapTest(jtu.JaxTestCase):
      return x

    hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
    self.assertIn("call @shmap_body", hlo_str)
    self.assertIn("call @shmap_body_0", hlo_str)
    self.assertIn("%arg0: tensor<1xf32>", hlo_str)
    self.assertIn("\"[None]\"", hlo_str)
    self.assertIn("%arg1: tensor<1xf32>", hlo_str)
    self.assertIn("\"[('i',)]\"", hlo_str)
    self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", hlo_str)
    if config.use_shardy_partitioner.value:
      self.assertEqual(2, hlo_str.count('sdy.manual_computation'))
    else:
      self.assertIn('call @shmap_body', hlo_str)
      self.assertIn('call @shmap_body_0', hlo_str)
      self.assertIn('%arg0: tensor<1xf32>', hlo_str)
      self.assertIn('"[None]"', hlo_str)
      self.assertIn('%arg1: tensor<1xf32>', hlo_str)
      self.assertIn('"[(\'i\',)]"', hlo_str)
      self.assertIn(
          '-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str
      )

  def test_rewrite_process_call(self):
    def f(x):
@@ -1759,10 +1768,18 @@ class ShardMapTest(jtu.JaxTestCase):

    v = jnp.arange(32.).reshape(4, 8)
    v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
    self.assertIn(
        'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}',
        f.lower(v).as_text('hlo'),
    )
    if config.use_shardy_partitioner.value:
      self.assertIn(
          'in_shardings=[<@mesh, [{"i"}, {}]>] out_shardings=[<@mesh, [{"i"},'
          ' {}]>] manual_axes={"i"}',
          f.lower(v).as_text(),
      )
    else:
      self.assertIn(
          'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual,'
          ' replicated}}',
          f.lower(v).as_text('hlo'),
      )
    self.assertAllClose(v*v, f(v), check_dtypes=False)

  def test_sharded_prng_with_abstract_mesh(self):
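The two branches in the hunk above look at different lowering texts: the GSPMD path annotates the HLO with `last_tile_dims={manual, replicated}`, while the Shardy path carries `in_shardings`/`out_shardings` and `manual_axes` in the StableHLO. A small sketch of inspecting both interactively, assuming `f` and `v` are the jitted function and sharded input from the test:

```py
# Which of these prints True depends on the jax_use_shardy_partitioner setting.
lowered = f.lower(v)  # `f`, `v` as in the test above (assumed in scope)
print('last_tile_dims={manual' in lowered.as_text('hlo'))  # GSPMD path
print('manual_axes={"i"}' in lowered.as_text())            # Shardy path
```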
@@ -1909,6 +1926,11 @@ class ShardMapTest(jtu.JaxTestCase):
    self.assertAllClose(jax.jit(f)(), jnp.zeros((2,)))

  def test_partial_auto_of_pjit_different_mesh(self):
    if config.use_shardy_partitioner.value:
      self.skipTest(
          'Shardy requires the mesh axis names to be the same across '
          'the entire computation.'
      )
    mesh = jtu.create_mesh((2, 2), ('i', 'j'))
    mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l'))

@@ -1977,10 +1999,14 @@ class ShardMapTest(jtu.JaxTestCase):
    xs = jnp.arange(16.)

    ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs)
    self.assertIn(
        '{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}',
        ir.as_text()
    )
    if config.use_shardy_partitioner.value:
      self.assertIn(
          'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text()
      )
    else:
      self.assertIn(
          "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text()
      )

  def test_vmap_spmd_axis_name_error(self):
    mesh = jtu.create_mesh((4, 2), ('i', 'j'))
@@ -2609,5 +2635,27 @@ class CustomPartitionerTest(jtu.JaxTestCase):
    self.assertEqual(c.addressable_data(0).shape, (4, 2))


@jtu.with_config(jax_use_shardy_partitioner=True)
class SdyIntegrationTest(jtu.JaxTestCase):
  # Verify we can lower to a `ManualComputationOp`.
  def test_shardy_collective_permute(self):
    mesh = jtu.create_mesh((2,), ('x',))
    a = jax.device_put(
        jnp.arange(8 * 8).reshape((8, 8)),
        jax.sharding.NamedSharding(mesh, P('x', None)),
    )

    @jax.jit
    @partial(
        shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
    )
    def fwd(a):
      axis_size = lax.psum(1, 'x')
      perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
      return lax.ppermute(a, 'x', perm=perm)

    self.assertIn('sdy.manual_computation', jax.jit(fwd).lower(a).as_text())


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())