Add unregistered mhlo.num_replicas and mhlo.num_partitions attributes to HLO output.

These attributes give PJRT plugin developers an inline way to determine the number of replicas and partitions a module is targeted at. There are no stability guarantees on these attributes at the moment.

PiperOrigin-RevId: 524013922
Author: Peter Hawkins, 2023-04-13 08:55:01 -07:00 (committed by jax authors)
Commit: 2e524411db (parent: fdbad53b15)
6 changed files with 82 additions and 10 deletions
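For orientation (not part of the commit), a minimal sketch of what the new attributes look like from the user side, assuming a single-device jit so both counts come out as 1:

```python
# Minimal sketch (not from this commit): observe the new module-level
# attributes on any lowered computation. Assumes a single-device jit.
import jax

hlo = jax.jit(lambda x: x + 4).lower(1.).as_text("stablehlo")
# The module header now carries the unregistered attributes, roughly:
#   module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32,
#                                    mhlo.num_replicas = 1 : i32} {...}
assert "mhlo.num_replicas = 1" in hlo
assert "mhlo.num_partitions = 1" in hlo
```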


@@ -552,6 +552,8 @@ def lower_jaxpr_to_module(
     result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
     arg_names: Optional[Sequence[Optional[str]]] = None,
     result_names: Optional[Sequence[Optional[str]]] = None,
+    num_replicas: int = 1,
+    num_partitions: int = 1,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.
@@ -600,13 +602,11 @@ def lower_jaxpr_to_module(
   with ctx.context, ir.Location.unknown(ctx.context):
     # Remove module name characters that XLA would alter. This ensures that
     # XLA computation preserves the module name.
+    attrs = ctx.module.operation.attributes
     module_name = _module_name_regex.sub("_", module_name)
-    ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(
-        module_name)
-    unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
-    if unlowerable_effects:
-      raise ValueError(
-          f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}')
+    attrs["sym_name"] = ir.StringAttr.get(module_name)
+    attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
+    attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
     lower_jaxpr_to_fun(
         ctx, "main", jaxpr, ordered_effects, public=True, create_tokens=True,
         replace_tokens_with_dummy=True,
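The `i32_attr` helper used in this hunk is not shown in the diff; presumably it wraps a Python int in a 32-bit signless `ir.IntegerAttr`, along these lines:

```python
# Assumed definition of the i32_attr helper used above (a sketch, not the
# diff's own code): build a 32-bit signless MLIR IntegerAttr. Must be called
# with an ir.Context active, as it is inside lower_jaxpr_to_module.
from jax._src.lib.mlir import ir

def i32_attr(i: int) -> ir.IntegerAttr:
  return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
```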


@@ -889,7 +889,8 @@ def lower_parallel_callable(
         arg_shardings=None,
         result_shardings=None,
         arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
-        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
+        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
+        num_replicas=replicas.num_global_replicas)
   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
@@ -1901,6 +1902,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
     axis_ctx = sharding_impls.ShardingContext(device_assignment)
+    num_partitions = len(device_assignment)
   else:
     # This path is triggered for `jit(pmap)` cases.
     replicated_args = None
@@ -1908,6 +1910,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     out_op_shardings = None
     axis_env = sharding_impls.AxisEnv(nreps, (), ())
     axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
+    num_partitions = 1
   module_name = f"{api_name}_{fun_name}"
@@ -1933,7 +1936,9 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       arg_shardings=in_op_shardings,
       result_shardings=out_op_shardings,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
-      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
+      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
+      num_replicas=nreps,
+      num_partitions=num_partitions)
   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
@@ -2224,6 +2229,8 @@ def lower_mesh_computation(
     out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(in_jaxpr_avals)
     axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
+    num_replicas = 1
+    num_partitions = mesh.devices.size
   else:
     replicated_args = [not get_array_mapping(i.spec) for i in in_shardings]  # type: ignore
     in_partitions = None
@@ -2233,6 +2240,8 @@ def lower_mesh_computation(
         names=tuple(global_axis_sizes.keys()),
         sizes=tuple(global_axis_sizes.values()))
     axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
+    num_replicas = mesh.devices.size
+    num_partitions = 1
   closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
   module: Union[str, xc.XlaComputation]
   module_name = f"{api_name}_{fun_name}"
@@ -2258,7 +2267,9 @@ def lower_mesh_computation(
       arg_shardings=in_partitions,
       result_shardings=out_partitions,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
-      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
+      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
+      num_replicas=num_replicas,
+      num_partitions=num_partitions)
   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
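Across these hunks, each lowering path derives the pair of counts from the device or mesh information it already has; an illustrative summary (not code from the commit, strings name variables from these hunks):

```python
# Illustrative summary (not from the commit) of the values each lowering
# path assigns to the new attributes.
ATTRIBUTE_CHOICES = {
    "jit/pjit (ShardingContext)":     ("1", "len(device_assignment)"),
    "jit(pmap) (ReplicaAxisContext)": ("nreps", "1"),
    "pmap":                           ("replicas.num_global_replicas", "1"),
    "mesh, SPMD (SPMDAxisContext)":   ("1", "mesh.devices.size"),
    "mesh, replicated":               ("mesh.devices.size", "1"),
}  # values are (num_replicas, num_partitions)
```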


@@ -1051,7 +1051,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
     self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
     self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))

-  def test_jit_lower_no_prunning(self):
+  def test_jit_replica_attributes(self):
+    hlo = self.jit(lambda x: x + 4).lower(1.).as_text("stablehlo")
+    self.assertIn("mhlo.num_partitions = 1", hlo)
+    self.assertIn("mhlo.num_replicas = 1", hlo)
+
+  def test_jit_lower_no_pruning(self):
     compiled = self.jit(lambda x, y: x + y).lower(1., 2.).compile()
     self.assertEqual(compiled._executable._kept_var_idx, {0, 1})
     self.assertLen(compiled._executable.in_avals, 2)
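A PJRT plugin that only has the serialized module text could recover the counts the same way this test does; a hypothetical helper sketch:

```python
# Hypothetical helper (not part of this commit): extract the counts from
# StableHLO text via string matching, defaulting to 1 when an attribute is
# absent (the attributes carry no stability guarantees yet).
import re

def replica_partition_counts(module_text: str) -> tuple[int, int]:
  replicas = re.search(r"mhlo\.num_replicas = (\d+)", module_text)
  partitions = re.search(r"mhlo\.num_partitions = (\d+)", module_text)
  return (int(replicas.group(1)) if replicas else 1,
          int(partitions.group(1)) if partitions else 1)
```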


@@ -970,6 +970,20 @@ class PJitTest(jtu.BufferDonationTestCase):
     self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
     self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))

+  @jtu.with_mesh([('x', 2)])
+  def testLowerPartitionsAttribute(self):
+    @partial(pjit,
+             in_shardings=(P('x'), P('x')),
+             out_shardings=None)
+    def f(x, y):
+      return x + y
+
+    shape = (8, 8)
+    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
+    hlo = f.lower(x, x + 1).as_text("stablehlo")
+    self.assertIn("mhlo.num_replicas = 1", hlo)
+    self.assertIn("mhlo.num_partitions = 2", hlo)
+
   @jtu.ignore_warning(category=DeprecationWarning)
   @jtu.with_mesh([('x', 2), ('y', 2)])
   def testLowerCompileCompilerIR(self):


@@ -393,6 +393,16 @@ class PythonPmapTest(jtu.JaxTestCase):
     x_shape = core.ShapedArray(x.shape, x.dtype)
     self.assertAllClose(f.lower(x_shape).compile()(x), f(x))

+  def testLowerHasReplicaAttributes(self):
+    f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
+    num_devices = jax.device_count()
+    shape = (num_devices, 4)
+    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
+    lowered = f.lower(x)
+    hlo = lowered.as_text("stablehlo")
+    self.assertIn(f"mhlo.num_replicas = {num_devices}", hlo)
+    self.assertIn("mhlo.num_partitions = 1", hlo)
+
   def testMean(self):
     f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
@@ -1864,6 +1874,25 @@ class PythonPmapTest(jtu.JaxTestCase):
         message=".*Using jit-of-pmap can lead to inefficient data movement"):
       x = foo(x)

+  @jtu.ignore_warning(
+      message=".*Using jit-of-pmap can lead to inefficient data movement")
+  def testJitOfPmapLowerHasReplicaAttributes(self):
+    device_count = jax.device_count()
+    if device_count == 1 or config.jax_disable_jit:
+      raise SkipTest("test requires at least two devices")
+
+    @jax.jit
+    @jax.pmap
+    def foo(x): return x + x
+
+    x = np.ones((2,2,2), dtype=np.float32)
+    hlo = foo.lower(x).as_text("stablehlo")
+    self.assertIn(f"mhlo.num_replicas = {2}", hlo)
+    self.assertIn("mhlo.num_partitions = 1", hlo)
+
   def testPsumZeroCotangents(self):
     # https://github.com/google/jax/issues/3651
     def loss(params, meta_params):

@@ -731,6 +731,19 @@ class XMapTest(XMapTestCase):
     self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
     self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))

+  @jtu.with_mesh([('x', 2)])
+  def testLowerPartitionsAttribute(self):
+    f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...],
+             axis_resources={'i': 'x'})
+    x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
+    hlo = f.lower(x).as_text(dialect='stablehlo')
+    if config.experimental_xmap_spmd_lowering:
+      self.assertIn("mhlo.num_partitions = 2", hlo)
+      self.assertIn("mhlo.num_replicas = 1", hlo)
+    else:
+      self.assertIn("mhlo.num_partitions = 1", hlo)
+      self.assertIn("mhlo.num_replicas = 2", hlo)
+
   @jtu.ignore_warning(category=DeprecationWarning)
   def testLowerCompileCompilerIR(self):
     # TODO(frostig): remove (deprecated)