Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 12:56:07 +00:00
Make custom partitioning work without a mesh context manager. If the arguments have a NamedSharding on them, then inside the partition function we should get a NamedSharding even when no mesh context manager is present.
PiperOrigin-RevId: 662146686
This commit is contained in:
parent
60bf5b7727
commit
53045380b1
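
Before the diff, a minimal usage sketch of the behavior this change enables, adapted from the new test added at the bottom of the diff. The axis name 'x', the 1-D mesh over all local devices, and the array size are illustrative assumptions, not part of the change:

# Sketch (assumptions: axis name 'x', 1-D mesh over all local devices).
import jax
import numpy as np
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

@custom_partitioning
def f(x):
  return x

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(x):
    return x
  # With this change, arg_shapes[0].sharding arrives as a NamedSharding even
  # though no `with mesh:` context manager is active anywhere in the program.
  x_sharding = arg_shapes[0].sharding
  return mesh, lower_fn, x_sharding, (x_sharding,)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  return arg_shapes[0].sharding

f.def_partition(partition=partition,
                infer_sharding_from_operands=infer_sharding_from_operands)

devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',))   # constructed explicitly, never entered as a context
s = NamedSharding(mesh, P('x'))
x = np.arange(8 * len(devices), dtype=np.float32)
y = jax.jit(f, in_shardings=s, out_shardings=s)(x)
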
@@ -24,6 +24,7 @@ import inspect
 from typing import Any
 import weakref
 
+import numpy as np
 import jax
 from jax import tree_util
 from jax._src import api_util

@@ -481,17 +482,20 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        infer_sharding_from_operands,
                                        decode_shardings,
                                        static_args):
-  mesh = mesh_lib.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context
   if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
       set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
     return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
 
+  mesh = mesh_lib.thread_resources.env.physical_mesh
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
     if devices is None:
       raise AssertionError(
           'Please file a bug at https://github.com/google/jax/issues')
+    if axis_context.mesh_shape is not None:
+      ma, ms = list(zip(*axis_context.mesh_shape))
+      mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = axis_context.mesh._flat_devices_tuple
   else:

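For context, the branch added above rebuilds a Mesh from the (axis_name, axis_size) pairs carried on the axis context plus its flat device assignment. A standalone sketch of that reshape (the helper name and the 2x4 layout are illustrative, not part of the change):

import numpy as np
import jax
from jax.sharding import Mesh

def mesh_from_shape(mesh_shape, devices):
  # mesh_shape is a tuple of (axis_name, axis_size) pairs, e.g. (('x', 2), ('y', 4)).
  axis_names, axis_sizes = zip(*mesh_shape)
  return Mesh(np.asarray(devices).reshape(axis_sizes), axis_names)

# e.g. with 8 devices: mesh_from_shape((('x', 2), ('y', 4)), jax.devices()[:8])
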
@@ -204,7 +204,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
 # stablehlo is oblivious of physical devices.
 prim_requires_devices_during_lowering: set[core.Primitive] = set()
 
-def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr):
+def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
   for eqn in jaxpr.eqns:
     if eqn.primitive in prim_requires_devices_during_lowering:
       return True

@@ -1033,7 +1033,6 @@ def lower_jaxpr_to_module(
     input_output_aliases: None | tuple[int | None, ...] = None,
     propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
     lowering_parameters: LoweringParameters,
-    mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.
 

@@ -1121,13 +1120,14 @@ def lower_jaxpr_to_module(
   # XLA computation preserves the module name.
   attrs = ctx.module.operation.attributes
   if config.use_shardy_partitioner.value:
-    assert mesh_shape_tuple is not None
+    assert (isinstance(axis_context, sharding_impls.ShardingContext) and
+            axis_context.mesh_shape is not None)
     ctx.module.body.append(
         dialects.sdy.MeshOp(
             "mesh",
             dialects.sdy.MeshAttr.get(
                 [dialects.sdy.MeshAxisAttr.get(name, size)
-                 for name, size in mesh_shape_tuple])))
+                 for name, size in axis_context.mesh_shape])))
   module_name = _module_name_regex.sub("_", module_name)
   attrs["sym_name"] = ir.StringAttr.get(module_name)
   attrs["mhlo.num_replicas"] = i32_attr(num_replicas)

@@ -1881,7 +1881,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            mesh_shape_tuple: tuple[tuple[str, int], ...]):
+                            mesh_shape_tuple: tuple[tuple[str, int], ...] | None):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings

@@ -1911,7 +1911,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
-    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
+    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
+                                              mesh_shape_tuple)
     num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.

@@ -1957,8 +1958,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       all_default_mem_kind=all_default_mem_kind,
       input_output_aliases=inout_aliases,
       propagated_out_mem_kinds=propagated_out_mem_kinds,
-      lowering_parameters=lowering_parameters,
-      mesh_shape_tuple=mesh_shape_tuple)
+      lowering_parameters=lowering_parameters)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))

@@ -2202,15 +2202,21 @@ def lower_sharding_computation(
       in_shardings, global_in_avals)  # type: ignore
   semantic_out_shardings = SemanticallyEqualShardings(
       out_shardings, global_out_avals)  # type: ignore
 
+  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
+
   mesh_shape_tuple = None
-  if config.use_shardy_partitioner.value:
-    for sharding in it.chain(
-        in_shardings, out_shardings,
+  if config.use_shardy_partitioner.value or prim_requires_devices:
+    for sharding in it.chain(in_shardings, out_shardings,
                              [js for js, _ in unique_intermediate_shardings]):
       if isinstance(sharding, sharding_impls.NamedSharding):
+        if (mesh_shape_tuple is not None and
+            mesh_shape_tuple != sharding.mesh.shape_tuple):
+          raise ValueError(
+              "mesh should be the same across the entire program. Got mesh"
+              f" shape for one sharding {mesh_shape_tuple} and"
+              f" {sharding.mesh.shape_tuple} for another")
         mesh_shape_tuple = sharding.mesh.shape_tuple
-        break
 
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(

@@ -1162,6 +1162,7 @@ class ShardingContext:
   """
   num_devices: int
   device_assignment: tuple[xc.Device, ...] | None = None
+  mesh_shape: tuple[tuple[str, int], ...] | None = None
 
   def __post_init__(self):
     if self.device_assignment is not None:

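A rough sketch of how the new field gets populated. ShardingContext is an internal class in jax._src.sharding_impls, the axis name 'x' below is an assumption, and the mesh_shape argument mirrors the pxla.py change above:

import jax
from jax._src import sharding_impls

devices = tuple(jax.devices())
ctx = sharding_impls.ShardingContext(
    num_devices=len(devices),
    device_assignment=devices,
    mesh_shape=(('x', len(devices)),),  # (axis_name, axis_size) pairs, or None
)
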
@@ -1514,6 +1514,45 @@ class CustomPartitionerTest(jtu.JaxTestCase):
     xs = jnp.ones([32, 16])
     self.assertEqual(pjit_f(xs), xs.sum())
 
+  def test_custom_partitioning_no_mesh_context(self):
+    self.skip_if_custom_partitioning_not_supported()
+
+    @custom_partitioning
+    def f(x):
+      return x
+
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(x):
+        @jax.jit
+        def g(y):
+          return y
+
+        return g(x)
+
+      x_shard = arg_shapes[0].sharding
+      return (
+          mesh,
+          lower_fn,
+          NamedSharding(x_shard.mesh, P('x')),
+          (NamedSharding(x_shard.mesh, P('x')),),
+      )
+
+    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      x_shard = arg_shapes[0].sharding
+      return NamedSharding(x_shard.mesh, P('x'))
+
+    f.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+    )
+
+    mesh = jtu.create_global_mesh((4,), ('x',))
+    x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32)
+    s = NamedSharding(mesh, P('x'))
+
+    jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
+    self.assertArraysEqual(x, jit_f(x))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class AutoShardingPjitTest(jtu.JaxTestCase):