[export] Adapt several collective lowering rules for multi-platform lowering

This fixes a few more places where the lowering rules used module_context.platform,
which is not supported for multi-platform lowering.
George Necula 2023-10-06 09:36:14 +02:00
parent e088a8e36b
commit a59ada03bd
4 changed files with 144 additions and 77 deletions
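For orientation: the commit replaces per-rule reads of ctx.module_context.platform with one of two patterns visible in the diffs below — binding the platform into the rule at registration time (partial(rule, platform=p) registered with platform=p), or consulting lowering_parameters.platforms when lowering for several platforms at once. The following is a minimal standalone sketch of the registration-time pattern; the registry, get_lowering helper, and primitive names here are toy stand-ins, not JAX's real mlir.register_lowering machinery, though the real API in the diff has the same shape.

```python
from functools import partial

# Toy stand-ins for JAX's lowering registry; the real code in the diff uses
# mlir.register_lowering(prim, rule, platform=...) with the same shape.
_lowering_rules = {}  # (primitive name, platform or None) -> rule

def register_lowering(prim_name, rule, platform=None):
  # platform=None registers the default rule; a platform string registers a
  # rule that is preferred when lowering for that platform.
  _lowering_rules[(prim_name, platform)] = rule

def get_lowering(prim_name, platform):
  # Fall back to the default rule when no platform-specific rule exists.
  return _lowering_rules.get((prim_name, platform),
                             _lowering_rules[(prim_name, None)])

def all_gather_lowering(ctx, x, *, platform=None):
  # The platform is a rule parameter bound at registration time, instead of
  # being read from ctx.module_context.platform, which is not meaningful when
  # a single module is lowered for several platforms at once.
  if platform in ("tpu", "cuda", "rocm"):
    return f"native all-gather lowering for {platform}"
  return "generic all-gather lowering"

register_lowering("all_gather", all_gather_lowering)  # default rule
for p in ("cuda", "rocm", "tpu"):
  register_lowering("all_gather",
                    partial(all_gather_lowering, platform=p), platform=p)

print(get_lowering("all_gather", "cpu")(None, None))   # generic lowering
print(get_lowering("all_gather", "cuda")(None, None))  # native for cuda
```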


@@ -1505,9 +1505,13 @@ def lower_multi_platform(ctx: LoweringRuleContext,
    rule_args: the args of the lowering rules.
    rule_kwargs: the kwargs of the lowering rules.
  """
  assert isinstance(ctx.module_context.lowering_parameters.platforms, tuple)
  platforms = ctx.module_context.lowering_parameters.platforms
  platforms_with_specific_rules = util.flatten(
  platforms: Sequence[str]
  if ctx.module_context.lowering_parameters.is_multi_platform:
    assert ctx.module_context.lowering_parameters.platforms is not None
    platforms = ctx.module_context.lowering_parameters.platforms
  else:
    platforms = (ctx.module_context.platform,)
  platforms_with_specific_rules: Sequence[str] = util.flatten(
      [ps for ps, _ in rules if ps is not None])
  platforms_with_default_rule = [p for p in platforms
                                 if p not in platforms_with_specific_rules]
@@ -1517,7 +1521,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
    rule_index = len(kept_rules)
    if ps is not None:
      # Keep only rules that mention the platforms of interest
      interesting_ps = [p for p in platforms if p in ps]
      interesting_ps = [p for p in platforms if p in ps]  # type: ignore
      if interesting_ps:
        for p in interesting_ps:
          assert p not in platform_to_kept_rules_idx


@@ -1352,7 +1352,6 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
  else:
    raise TypeError(aval)

def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int):
  return sharding_impls.AxisEnv(env.nreps, env.names + (name,),
                                env.sizes + (size,))


@@ -725,10 +725,15 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups):
          for arg, named_shape in zip(args, named_shapes)]

def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
  if axis_index_groups is not None and ctx.module_context.platform == "tpu":
  # TODO(necula): clean this up when we have module_context.platforms
  if ctx.module_context.lowering_parameters.is_multi_platform:
    for_tpu = ("tpu" in ctx.module_context.lowering_parameters.platforms)
  else:
    for_tpu = (ctx.module_context.platform == "tpu")
  if axis_index_groups is not None and for_tpu:
    len_0 = len(axis_index_groups[0])
    if any(len(g) != len_0 for g in axis_index_groups):
      raise ValueError("axis_index_groups must all be the same size")
      raise ValueError("axis_index_groups must all be the same size for TPU lowering")
  named_axes, positional_axes = axes_partition = [], []
  for axis in axes:
    axes_partition[isinstance(axis, int)].append(axis)
@@ -1175,7 +1180,8 @@ def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, a
  raise AssertionError("Unexpected call to _all_gather_impl")

def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
                         axis_index_groups, axis_size, tiled):
                         axis_index_groups, axis_size, tiled,
                         platform=None):
  # TODO(jekbradbury): enable for all_gather_dimension > 0
  x_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
@@ -1184,9 +1190,8 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
      axis_context,
      (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
  )
  if (ctx.module_context.platform == 'tpu' or
      ctx.module_context.platform in ('cuda', 'rocm')
      and all_gather_dimension == 0):
  if (platform == 'tpu' or
      (platform in ('cuda', 'rocm') and all_gather_dimension == 0)):
    if not tiled:
      new_shape = list(x_aval.shape)
      new_shape.insert(all_gather_dimension, 1)
@@ -1282,6 +1287,10 @@ all_gather_p = core.AxisPrimitive('all_gather')
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
all_gather_p.def_impl(_all_gather_impl)
mlir.register_lowering(all_gather_p, _all_gather_lowering)
for p in ("cuda", "rocm", "tpu"):
  mlir.register_lowering(all_gather_p,
                         partial(_all_gather_lowering, platform=p),
                         platform=p)
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
batching.primitive_batchers[all_gather_p] = _all_gather_batcher
batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
@@ -1313,63 +1322,68 @@ def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
  return outs

def _reduce_scatter_lowering(prim, reducer, ctx, x,
                             *, scatter_dimension, axis_name,
                             axis_index_groups, axis_size, tiled):
  if ctx.module_context.platform in ("tpu", "cuda", "rocm"):
    x_aval, = ctx.avals_in
    aval_out, = ctx.avals_out
    scalar_aval = x_aval.update(shape=())
    replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
                                     axis_index_groups)
    scatter_out_shape = list(x_aval.shape)
    scatter_out_shape[scatter_dimension] //= axis_size
    axis_context = ctx.module_context.axis_context
    is_spmd = isinstance(
        axis_context,
        (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
    )
    if is_spmd:
      # We want to emit the all-gather with global device IDs and a unique
      # channel ID, as otherwise it interprets the devices as replicas instead
      # of partitions - and XLA is configured with only a single replica.
      channel = ctx.module_context.new_channel()
      other_args = dict(
          channel_handle=hlo.ChannelHandle.get(
              channel, mlir.DEVICE_TO_DEVICE_TYPE),
          use_global_device_ids=ir.BoolAttr.get(True))
    else:
      other_args = {}
    op = hlo.ReduceScatterOp(
        mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
        x,
        scatter_dimension=mlir.i64_attr(scatter_dimension),
        replica_groups=_replica_groups_hlo(replica_groups),
        **other_args)
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer_block):
      lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
      reducer_ctx = ctx.replace(primitive=None,
                                avals_in=[scalar_aval] * 2,
                                avals_out=[scalar_aval])
      out_nodes = lower_reducer(
          reducer_ctx, *([a] for a in reducer_block.arguments))
      hlo.ReturnOp(util.flatten(out_nodes))
    if tiled:
      return op.results
    else:
      return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
def _reduce_scatter_lowering(
    prim, reducer, ctx, x,
    *, scatter_dimension, axis_name,
    axis_index_groups, axis_size, tiled):
  x_aval, = ctx.avals_in
  aval_out, = ctx.avals_out
  scalar_aval = x_aval.update(shape=())
  replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
                                   axis_index_groups)
  scatter_out_shape = list(x_aval.shape)
  scatter_out_shape[scatter_dimension] //= axis_size
  axis_context = ctx.module_context.axis_context
  is_spmd = isinstance(
      axis_context,
      (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
  )
  if is_spmd:
    # We want to emit the all-gather with global device IDs and a unique
    # channel ID, as otherwise it interprets the devices as replicas instead
    # of partitions - and XLA is configured with only a single replica.
    channel = ctx.module_context.new_channel()
    other_args = dict(
        channel_handle=hlo.ChannelHandle.get(
            channel, mlir.DEVICE_TO_DEVICE_TYPE),
        use_global_device_ids=ir.BoolAttr.get(True))
  else:
    return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
        ctx, x,
        reducer=reducer,
        scatter_dimension=scatter_dimension,
        axis_name=axis_name,
        axis_index_groups=axis_index_groups,
        axis_size=axis_size,
        tiled=tiled)
    other_args = {}
  op = hlo.ReduceScatterOp(
      mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
      x,
      scatter_dimension=mlir.i64_attr(scatter_dimension),
      replica_groups=_replica_groups_hlo(replica_groups),
      **other_args)
  scalar_type = mlir.aval_to_ir_type(scalar_aval)
  reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(reducer_block):
    lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
    reducer_ctx = ctx.replace(primitive=None,
                              avals_in=[scalar_aval] * 2,
                              avals_out=[scalar_aval])
    out_nodes = lower_reducer(
        reducer_ctx, *([a] for a in reducer_block.arguments))
    hlo.ReturnOp(util.flatten(out_nodes))
  if tiled:
    return op.results
  else:
    return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results

def _reduce_scatter_lowering_via_reducer(
    prim, reducer, ctx, x,
    *, scatter_dimension, axis_name,
    axis_index_groups, axis_size, tiled):
  return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
      ctx, x,
      reducer=reducer,
      scatter_dimension=scatter_dimension,
      axis_name=axis_name,
      axis_index_groups=axis_index_groups,
      axis_size=axis_size,
      tiled=tiled)

def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
@@ -1449,9 +1463,17 @@ reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
mlir.register_lowering(
    reduce_scatter_p,
    partial(_reduce_scatter_lowering, lax.add_p, psum))
    partial(_reduce_scatter_lowering_via_reducer, lax.add_p, psum))

reduce_scatter_lowering_for_psum = partial(_reduce_scatter_lowering,
                                           lax.add_p, psum)
for p in ("tpu", "cuda", "rocm"):
  mlir.register_lowering(
      reduce_scatter_p, reduce_scatter_lowering_for_psum,
      platform=p)

core.axis_substitution_rules[reduce_scatter_p] = \
    partial(_subst_all_names_in_param, 'axis_name')


@@ -13,12 +13,18 @@
# limitations under the License.
"""Tests for multi-platform and cross-platform JAX export."""
import math
import re
from typing import Literal
from typing import Callable, Sequence
from absl import logging
from absl.testing import absltest
import numpy as np
import jax
from jax import lax
from jax._src import pjit
from jax._src import test_util as jtu
from jax.experimental.export import export
# TODO(necula): Move the primitive harness out of jax2tf so that we can move
@@ -46,7 +52,6 @@ _skip_cuda_lowering_unless_have_gpus = make_disjunction_regexp(
    "random_",
)

class PrimitiveTest(jtu.JaxTestCase):

  @classmethod
@@ -88,8 +93,21 @@ class PrimitiveTest(jtu.JaxTestCase):
    for l in harness.jax_unimplemented:
      if l.filter(dtype=harness.dtype):
        unimplemented_platforms = unimplemented_platforms.union(l.devices)
    if (_skip_cuda_lowering_unless_have_gpus.search(harness.fullname)
        and all(d.platform != "gpu" for d in self.devices)):
      unimplemented_platforms.add("gpu")
    logging.info("Harness is not implemented on %s", unimplemented_platforms)
    self.export_and_compare_to_native(
        func_jax, *args,
        unimplemented_platforms=unimplemented_platforms)

  def export_and_compare_to_native(
      self, func_jax: Callable,
      *args: jax.Array,
      unimplemented_platforms: set[str] = set(),
      skip_run_on_platforms: set[str] = set()):
    devices = [
        d
        for d in self.__class__.devices
@@ -99,14 +117,9 @@ class PrimitiveTest(jtu.JaxTestCase):
    # lowering_platforms uses "cuda" instead of "gpu"
    lowering_platforms: list[str] = [
        p if p != "gpu" else "cuda"
        for p in {"cpu", "gpu", "tpu"} - unimplemented_platforms
        for p in ("cpu", "gpu", "tpu")
        if p not in unimplemented_platforms
    ]
    if (
        "cuda" in lowering_platforms
        and _skip_cuda_lowering_unless_have_gpus.search(harness.fullname)
        and all(d.platform != "gpu" for d in devices)
    ):
      lowering_platforms.remove("cuda")
    if len(lowering_platforms) <= 1:
      self.skipTest(
@@ -117,6 +130,9 @@ class PrimitiveTest(jtu.JaxTestCase):
    exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args)
    for device in devices:
      if device.platform in skip_run_on_platforms:
        logging.info("Skipping running on %s", device)
        continue
      device_args = jax.tree_util.tree_map(
          lambda x: jax.device_put(x, device), args
      )
@@ -127,6 +143,32 @@ class PrimitiveTest(jtu.JaxTestCase):
      self.assertAllClose(native_res, exported_res)
    # TODO(necula): Check HLO equivalence for the ultimate test.

  def test_psum_scatter(self):
    f = jax.jit(jax.pmap(lambda x: lax.psum_scatter(x, 'i'),
                         axis_name='i',
                         devices=jax.devices()[:1]))
    shape = (1, 1, 8)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    self.export_and_compare_to_native(f, x)

  # The lowering rule for all_gather has special cases for bool.
  @jtu.parameterized_filterable(
    kwargs=[
      dict(dtype=dtype)
      for dtype in [np.bool_, np.float32]],
  )
  def test_all_gather(self, *, dtype):
    f = jax.jit(jax.pmap(lambda x: lax.all_gather(x, 'i'),
                         axis_name='i',
                         devices=jax.devices()[:1]))
    shape = (1, 4)
    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    if dtype == np.bool_:
      x = (x % 2).astype(np.bool_)
    self.export_and_compare_to_native(f, x)

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
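As a usage reference for the export path these tests exercise, here is a minimal sketch assuming a single local device. The export.export(..., lowering_platforms=...) call mirrors the test above; the ("cpu", "tpu") platform tuple and the export.call_exported invocation helper are illustrative assumptions, not taken from this diff, and the API may differ in other JAX versions.

```python
import math
import numpy as np
import jax
from jax import lax
from jax.experimental.export import export

def f(x):
  return lax.psum_scatter(x, "i")

# pmap over a single device, as in test_psum_scatter above.
fn = jax.jit(jax.pmap(f, axis_name="i", devices=jax.devices()[:1]))
shape = (1, 1, 8)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

# One Exported artifact that embeds lowerings for several platforms
# (illustrative platform list; only platforms supported by the primitive work).
exp = export.export(fn, lowering_platforms=("cpu", "tpu"))(x)

# Invoking the exported artifact from JAX again; call_exported is assumed here
# as the invocation helper and runs on the current default backend.
res = export.call_exported(exp)(x)
np.testing.assert_allclose(res, fn(x))
```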