mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add unregistered mhlo.num_replicas and mhlo.num_partitions attributes to HLO output.
These attributes give PJRT plugin developers an inline way to determine the number of replicas and partitions that the module targets. There are currently no stability guarantees for these attributes. PiperOrigin-RevId: 524013922
This commit is contained in:
parent
fdbad53b15
commit
2e524411db
@ -552,6 +552,8 @@ def lower_jaxpr_to_module(
|
||||
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
|
||||
arg_names: Optional[Sequence[Optional[str]]] = None,
|
||||
result_names: Optional[Sequence[Optional[str]]] = None,
|
||||
num_replicas: int = 1,
|
||||
num_partitions: int = 1,
|
||||
) -> LoweringResult:
|
||||
"""Lowers a top-level jaxpr to an MLIR module.
|
||||
|
||||
@ -600,13 +602,11 @@ def lower_jaxpr_to_module(
|
||||
with ctx.context, ir.Location.unknown(ctx.context):
|
||||
# Remove module name characters that XLA would alter. This ensures that
|
||||
# XLA computation preserves the module name.
|
||||
attrs = ctx.module.operation.attributes
|
||||
module_name = _module_name_regex.sub("_", module_name)
|
||||
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(
|
||||
module_name)
|
||||
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
|
||||
if unlowerable_effects:
|
||||
raise ValueError(
|
||||
f'Cannot lower jaxpr with unlowerable effects: {unlowerable_effects}')
|
||||
attrs["sym_name"] = ir.StringAttr.get(module_name)
|
||||
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
|
||||
attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
|
||||
lower_jaxpr_to_fun(
|
||||
ctx, "main", jaxpr, ordered_effects, public=True, create_tokens=True,
|
||||
replace_tokens_with_dummy=True,
|
||||
|
@ -889,7 +889,8 @@ def lower_parallel_callable(
|
||||
arg_shardings=None,
|
||||
result_shardings=None,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=replicas.num_global_replicas)
|
||||
module, keepalive, host_callbacks = (
|
||||
lowering_result.module, lowering_result.keepalive,
|
||||
lowering_result.host_callbacks)
|
||||
@ -1901,6 +1902,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
|
||||
replicated_args = [False] * len(global_in_avals)
|
||||
axis_ctx = sharding_impls.ShardingContext(device_assignment)
|
||||
num_partitions = len(device_assignment)
|
||||
else:
|
||||
# This path is triggered for `jit(pmap)` cases.
|
||||
replicated_args = None
|
||||
@ -1908,6 +1910,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
out_op_shardings = None
|
||||
axis_env = sharding_impls.AxisEnv(nreps, (), ())
|
||||
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
|
||||
num_partitions = 1
|
||||
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
|
||||
@ -1933,7 +1936,9 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
arg_shardings=in_op_shardings,
|
||||
result_shardings=out_op_shardings,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=nreps,
|
||||
num_partitions=num_partitions)
|
||||
module, keepalive, host_callbacks = (
|
||||
lowering_result.module, lowering_result.keepalive,
|
||||
lowering_result.host_callbacks)
|
||||
@ -2224,6 +2229,8 @@ def lower_mesh_computation(
|
||||
out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
|
||||
replicated_args = [False] * len(in_jaxpr_avals)
|
||||
axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
|
||||
num_replicas = 1
|
||||
num_partitions = mesh.devices.size
|
||||
else:
|
||||
replicated_args = [not get_array_mapping(i.spec) for i in in_shardings] # type: ignore
|
||||
in_partitions = None
|
||||
@ -2233,6 +2240,8 @@ def lower_mesh_computation(
|
||||
names=tuple(global_axis_sizes.keys()),
|
||||
sizes=tuple(global_axis_sizes.values()))
|
||||
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
|
||||
num_replicas = mesh.devices.size
|
||||
num_partitions = 1
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module: Union[str, xc.XlaComputation]
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
@ -2258,7 +2267,9 @@ def lower_mesh_computation(
|
||||
arg_shardings=in_partitions,
|
||||
result_shardings=out_partitions,
|
||||
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
|
||||
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions)
|
||||
module, keepalive, host_callbacks = (
|
||||
lowering_result.module, lowering_result.keepalive,
|
||||
lowering_result.host_callbacks)
|
||||
|
@ -1051,7 +1051,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
|
||||
|
||||
def test_jit_lower_no_prunning(self):
|
||||
def test_jit_replica_attributes(self):
|
||||
hlo = self.jit(lambda x: x + 4).lower(1.).as_text("stablehlo")
|
||||
self.assertIn("mhlo.num_partitions = 1", hlo)
|
||||
self.assertIn("mhlo.num_replicas = 1", hlo)
|
||||
|
||||
def test_jit_lower_no_pruning(self):
|
||||
compiled = self.jit(lambda x, y: x + y).lower(1., 2.).compile()
|
||||
self.assertEqual(compiled._executable._kept_var_idx, {0, 1})
|
||||
self.assertLen(compiled._executable.in_avals, 2)
|
||||
|
@ -970,6 +970,20 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testLowerPartitionsAttribute(self):
|
||||
@partial(pjit,
|
||||
in_shardings=(P('x'), P('x')),
|
||||
out_shardings=None)
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
shape = (8, 8)
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
hlo = f.lower(x, x + 1).as_text("stablehlo")
|
||||
self.assertIn("mhlo.num_replicas = 1", hlo)
|
||||
self.assertIn("mhlo.num_partitions = 2", hlo)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning)
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def testLowerCompileCompilerIR(self):
|
||||
|
@ -393,6 +393,16 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
x_shape = core.ShapedArray(x.shape, x.dtype)
|
||||
self.assertAllClose(f.lower(x_shape).compile()(x), f(x))
|
||||
|
||||
def testLowerHasReplicaAttributes(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
num_devices = jax.device_count()
|
||||
shape = (num_devices, 4)
|
||||
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
lowered = f.lower(x)
|
||||
hlo = lowered.as_text("stablehlo")
|
||||
self.assertIn(f"mhlo.num_replicas = {num_devices}", hlo)
|
||||
self.assertIn("mhlo.num_partitions = 1", hlo)
|
||||
|
||||
def testMean(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
|
||||
@ -1864,6 +1874,25 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
message=".*Using jit-of-pmap can lead to inefficient data movement"):
|
||||
x = foo(x)
|
||||
|
||||
@jtu.ignore_warning(
|
||||
message=".*Using jit-of-pmap can lead to inefficient data movement")
|
||||
def testJitOfPmapLowerHasReplicaAttributes(self):
|
||||
device_count = jax.device_count()
|
||||
|
||||
if device_count == 1 or config.jax_disable_jit:
|
||||
raise SkipTest("test requires at least two devices")
|
||||
|
||||
@jax.jit
|
||||
@jax.pmap
|
||||
def foo(x): return x + x
|
||||
|
||||
x = np.ones((2,2,2), dtype=np.float32)
|
||||
|
||||
hlo = foo.lower(x).as_text("stablehlo")
|
||||
self.assertIn(f"mhlo.num_replicas = {2}", hlo)
|
||||
self.assertIn("mhlo.num_partitions = 1", hlo)
|
||||
|
||||
|
||||
def testPsumZeroCotangents(self):
|
||||
# https://github.com/google/jax/issues/3651
|
||||
def loss(params, meta_params):
|
||||
|
@ -731,6 +731,19 @@ class XMapTest(XMapTestCase):
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testLowerPartitionsAttribute(self):
|
||||
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...],
|
||||
axis_resources={'i': 'x'})
|
||||
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
|
||||
hlo = f.lower(x).as_text(dialect='stablehlo')
|
||||
if config.experimental_xmap_spmd_lowering:
|
||||
self.assertIn("mhlo.num_partitions = 2", hlo)
|
||||
self.assertIn("mhlo.num_replicas = 1", hlo)
|
||||
else:
|
||||
self.assertIn("mhlo.num_partitions = 1", hlo)
|
||||
self.assertIn("mhlo.num_replicas = 2", hlo)
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning)
|
||||
def testLowerCompileCompilerIR(self):
|
||||
# TODO(frostig): remove (deprecated)
|
||||
|
Loading…
x
Reference in New Issue
Block a user