Make custom partitioning work without a mesh context manager. If the arguments have a NamedSharding on them, then inside the partition function we should get NamedShardings even when no mesh context manager is active (see the usage sketch below).

PiperOrigin-RevId: 662146686
Yash Katariya 2024-08-12 10:39:58 -07:00 committed by jax authors
parent 60bf5b7727
commit 53045380b1
6 changed files with 64 additions and 14 deletions
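For context, here is a minimal usage sketch of what this change enables, mirroring the test added in tests/pjit_test.py below. This is a sketch and not part of the commit; it assumes a host with at least 4 devices, and the names f, partition, and infer_sharding_from_operands are illustrative. The partitioning callbacks receive NamedShardings derived from the argument shardings even though no `with mesh:` context manager is active.

    import jax
    import numpy as np
    from jax.experimental.custom_partitioning import custom_partitioning
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    @custom_partitioning
    def f(x):
      return x

    def partition(mesh, arg_shapes, result_shape):
      # arg_shapes[0].sharding is a NamedSharding even though no mesh
      # context manager is active.
      arg_sharding = arg_shapes[0].sharding
      def lower_fn(x):
        return x
      out_sharding = NamedSharding(arg_sharding.mesh, P('x'))
      return mesh, lower_fn, out_sharding, (out_sharding,)

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
      return NamedSharding(arg_shapes[0].sharding.mesh, P('x'))

    f.def_partition(partition=partition,
                    infer_sharding_from_operands=infer_sharding_from_operands)

    mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
    s = NamedSharding(mesh, P('x'))
    x = np.arange(32, dtype=np.float32)
    out = jax.jit(f, in_shardings=s, out_shardings=s)(x)  # no `with mesh:` anywhere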

jax/_src/custom_partitioning.py

@@ -24,6 +24,7 @@ import inspect
 from typing import Any
 import weakref
+import numpy as np
 import jax
 from jax import tree_util
 from jax._src import api_util
@@ -481,17 +482,20 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        infer_sharding_from_operands,
                                        decode_shardings,
                                        static_args):
-  mesh = mesh_lib.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context
   if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
       set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
     return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
+  mesh = mesh_lib.thread_resources.env.physical_mesh
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
     if devices is None:
       raise AssertionError(
           'Please file a bug at https://github.com/google/jax/issues')
+    if axis_context.mesh_shape is not None:
+      ma, ms = list(zip(*axis_context.mesh_shape))
+      mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = axis_context.mesh._flat_devices_tuple
   else:
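In isolation, the new ShardingContext branch above rebuilds a Mesh from the (axis name, size) pairs plus the flat device assignment. A standalone sketch of that reshaping with hypothetical values (assumes at least 4 devices; this is not the actual lowering code):

    import jax
    import numpy as np
    from jax.sharding import Mesh

    # Stand-ins for ShardingContext.mesh_shape and the device assignment
    # that the lowering rule receives.
    mesh_shape = (('x', 2), ('y', 2))
    devices = jax.devices()[:4]

    axis_names, axis_sizes = zip(*mesh_shape)
    mesh = Mesh(np.array(devices).reshape(axis_sizes), axis_names)
    print(mesh.shape_tuple)  # (('x', 2), ('y', 2))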

jax/_src/dispatch.py

@@ -204,7 +204,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
 # stablehlo is oblivious of physical devices.
 prim_requires_devices_during_lowering: set[core.Primitive] = set()
-def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr):
+def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
   for eqn in jaxpr.eqns:
     if eqn.primitive in prim_requires_devices_during_lowering:
       return True
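For reference, lowering rules that need concrete devices opt in by adding their primitive to this set. A hedged sketch with a hypothetical primitive (the registration for the actual custom-partitioning primitive is not shown in this excerpt):

    from jax._src import core, dispatch

    # Hypothetical primitive, used only to illustrate the registry.
    my_prim = core.Primitive('my_device_dependent_prim')
    dispatch.prim_requires_devices_during_lowering.add(my_prim)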

jax/_src/interpreters/mlir.py

@@ -1033,7 +1033,6 @@ def lower_jaxpr_to_module(
     input_output_aliases: None | tuple[int | None, ...] = None,
     propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
     lowering_parameters: LoweringParameters,
-    mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.
@@ -1121,13 +1120,14 @@ def lower_jaxpr_to_module(
   # XLA computation preserves the module name.
   attrs = ctx.module.operation.attributes
   if config.use_shardy_partitioner.value:
-    assert mesh_shape_tuple is not None
+    assert (isinstance(axis_context, sharding_impls.ShardingContext) and
+            axis_context.mesh_shape is not None)
     ctx.module.body.append(
         dialects.sdy.MeshOp(
             "mesh",
             dialects.sdy.MeshAttr.get(
                 [dialects.sdy.MeshAxisAttr.get(name, size)
-                 for name, size in mesh_shape_tuple])))
+                 for name, size in axis_context.mesh_shape])))
   module_name = _module_name_regex.sub("_", module_name)
   attrs["sym_name"] = ir.StringAttr.get(module_name)
   attrs["mhlo.num_replicas"] = i32_attr(num_replicas)

jax/_src/interpreters/pxla.py

@@ -1881,7 +1881,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            mesh_shape_tuple: tuple[tuple[str, int], ...]):
+                            mesh_shape_tuple: tuple[tuple[str, int], ...] | None):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1911,7 +1911,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
-    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
+    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
+                                              mesh_shape_tuple)
     num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.
@@ -1957,8 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       all_default_mem_kind=all_default_mem_kind,
       input_output_aliases=inout_aliases,
       propagated_out_mem_kinds=propagated_out_mem_kinds,
-      lowering_parameters=lowering_parameters,
-      mesh_shape_tuple=mesh_shape_tuple)
+      lowering_parameters=lowering_parameters)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -2202,15 +2202,21 @@ def lower_sharding_computation(
       in_shardings, global_in_avals)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(
       out_shardings, global_out_avals)  # type: ignore
+  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
   mesh_shape_tuple = None
-  if config.use_shardy_partitioner.value:
-    for sharding in it.chain(
-        in_shardings, out_shardings,
+  if config.use_shardy_partitioner.value or prim_requires_devices:
+    for sharding in it.chain(in_shardings, out_shardings,
                              [js for js, _ in unique_intermediate_shardings]):
       if isinstance(sharding, sharding_impls.NamedSharding):
+        if (mesh_shape_tuple is not None and
+            mesh_shape_tuple != sharding.mesh.shape_tuple):
+          raise ValueError(
+              "mesh should be the same across the entire program. Got mesh"
+              f" shape for one sharding {mesh_shape_tuple} and"
+              f" {sharding.mesh.shape_tuple} for another")
         mesh_shape_tuple = sharding.mesh.shape_tuple
-        break
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
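The new check above requires every NamedSharding in the program to come from a mesh with the same (name, size) tuple. A small illustration of the quantity being compared, Mesh.shape_tuple (a sketch assuming at least 2 devices):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    devs = np.array(jax.devices()[:2])
    s_a = NamedSharding(Mesh(devs, ('x',)), P('x'))
    s_b = NamedSharding(Mesh(devs, ('y',)), P('y'))
    print(s_a.mesh.shape_tuple)  # (('x', 2),)
    print(s_b.mesh.shape_tuple)  # (('y', 2),)
    # Mixing these two in one lowered program would now raise the
    # "mesh should be the same across the entire program" ValueError.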

jax/_src/sharding_impls.py

@@ -1162,6 +1162,7 @@ class ShardingContext:
   """
   num_devices: int
   device_assignment: tuple[xc.Device, ...] | None = None
+  mesh_shape: tuple[tuple[str, int], ...] | None = None
   def __post_init__(self):
     if self.device_assignment is not None:
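A sketch of how the lowering code above populates this field (internal API from jax._src, shown only for illustration; the mesh axis name 'x' is a stand-in):

    import jax
    from jax._src import sharding_impls

    devices = tuple(jax.devices())
    ctx = sharding_impls.ShardingContext(
        num_devices=len(devices),
        device_assignment=devices,
        mesh_shape=(('x', len(devices)),),
    )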

tests/pjit_test.py

@@ -1514,6 +1514,45 @@ class CustomPartitionerTest(jtu.JaxTestCase):
     xs = jnp.ones([32, 16])
     self.assertEqual(pjit_f(xs), xs.sum())
+  def test_custom_partitioning_no_mesh_context(self):
+    self.skip_if_custom_partitioning_not_supported()
+
+    @custom_partitioning
+    def f(x):
+      return x
+
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(x):
+        @jax.jit
+        def g(y):
+          return y
+        return g(x)
+
+      x_shard = arg_shapes[0].sharding
+      return (
+          mesh,
+          lower_fn,
+          NamedSharding(x_shard.mesh, P('x')),
+          (NamedSharding(x_shard.mesh, P('x')),),
+      )
+
+    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      x_shard = arg_shapes[0].sharding
+      return NamedSharding(x_shard.mesh, P('x'))
+
+    f.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+    )
+
+    mesh = jtu.create_global_mesh((4,), ('x',))
+    x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32)
+    s = NamedSharding(mesh, P('x'))
+
+    jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
+    self.assertArraysEqual(x, jit_f(x))
+
 @jtu.pytest_mark_if_available('multiaccelerator')
 class AutoShardingPjitTest(jtu.JaxTestCase):