Mirror of https://github.com/ROCm/jax.git
[export] Adapt several collective lowering rules for multi-platform lowering
This fixes a few more places where the lowering rules used module_context.platform, which is not supported for multi-platform lowering.
Parent: e088a8e36b
Commit: a59ada03bd
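The pattern applied throughout the hunks below: rather than reading ctx.module_context.platform inside a lowering rule, the rule gains a `platform` keyword that is bound at registration time, and the rule is registered once per platform. A minimal sketch of that pattern, using a hypothetical primitive (`my_prim` and `_my_lowering` are illustrative, not part of the commit; the registration calls mirror the all_gather hunk below):

    from functools import partial

    from jax._src import core
    from jax._src.interpreters import mlir

    my_prim = core.Primitive("my_prim")  # hypothetical primitive, for illustration

    def _my_lowering(ctx, x, *, platform=None):
      # `platform` is bound at registration time; None marks the default rule,
      # which must not assume any particular backend.
      del ctx  # a real rule would emit HLO ops here
      return [x]

    mlir.register_lowering(my_prim, _my_lowering)             # default rule
    for p in ("cuda", "rocm", "tpu"):                         # platform-specific rules
      mlir.register_lowering(my_prim,
                             partial(_my_lowering, platform=p),
                             platform=p)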
@@ -1505,9 +1505,13 @@ def lower_multi_platform(ctx: LoweringRuleContext,
     rule_args: the args of the lowering rules.
     rule_kwargs: the kwargs of the lowering rules.
   """
-  assert isinstance(ctx.module_context.lowering_parameters.platforms, tuple)
-  platforms = ctx.module_context.lowering_parameters.platforms
-  platforms_with_specific_rules = util.flatten(
+  platforms: Sequence[str]
+  if ctx.module_context.lowering_parameters.is_multi_platform:
+    assert ctx.module_context.lowering_parameters.platforms is not None
+    platforms = ctx.module_context.lowering_parameters.platforms
+  else:
+    platforms = (ctx.module_context.platform,)
+  platforms_with_specific_rules: Sequence[str] = util.flatten(
     [ps for ps, _ in rules if ps is not None])
   platforms_with_default_rule = [p for p in platforms
                                  if p not in platforms_with_specific_rules]

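To make the two branches above concrete, a small runnable sketch with stand-in context objects (SimpleNamespace is illustrative; the real context is a LoweringRuleContext, with attribute names as in the hunk):

    from types import SimpleNamespace

    # Multi-platform export: an explicit tuple of platforms is carried along.
    multi = SimpleNamespace(module_context=SimpleNamespace(
        lowering_parameters=SimpleNamespace(is_multi_platform=True,
                                            platforms=("cpu", "tpu")),
        platform=None))

    # Regular lowering: only the single module platform is known.
    single = SimpleNamespace(module_context=SimpleNamespace(
        lowering_parameters=SimpleNamespace(is_multi_platform=False,
                                            platforms=None),
        platform="cpu"))

    def lowering_platforms(ctx):
      # Same branch structure as the hunk above.
      if ctx.module_context.lowering_parameters.is_multi_platform:
        assert ctx.module_context.lowering_parameters.platforms is not None
        return ctx.module_context.lowering_parameters.platforms
      return (ctx.module_context.platform,)

    assert lowering_platforms(multi) == ("cpu", "tpu")
    assert lowering_platforms(single) == ("cpu",)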
@@ -1517,7 +1521,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
     rule_index = len(kept_rules)
     if ps is not None:
       # Keep only rules that mention the platforms of interest
-      interesting_ps = [p for p in platforms if p in ps]
+      interesting_ps = [p for p in platforms if p in ps]  # type: ignore
       if interesting_ps:
         for p in interesting_ps:
           assert p not in platform_to_kept_rules_idx

@@ -1352,7 +1352,6 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
   else:
     raise TypeError(aval)

-
 def _extend_axis_env(env: sharding_impls.AxisEnv, name, size: int):
   return sharding_impls.AxisEnv(env.nreps, env.names + (name,),
                                 env.sizes + (size,))

@@ -725,10 +725,15 @@ def _allreduce_abstract_eval(*args, axes, axis_index_groups):
           for arg, named_shape in zip(args, named_shapes)]

 def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
-  if axis_index_groups is not None and ctx.module_context.platform == "tpu":
+  # TODO(necula): clean this up when we have module_context.platforms
+  if ctx.module_context.lowering_parameters.is_multi_platform:
+    for_tpu = ("tpu" in ctx.module_context.lowering_parameters.platforms)
+  else:
+    for_tpu = (ctx.module_context.platform == "tpu")
+  if axis_index_groups is not None and for_tpu:
     len_0 = len(axis_index_groups[0])
     if any(len(g) != len_0 for g in axis_index_groups):
-      raise ValueError("axis_index_groups must all be the same size")
+      raise ValueError("axis_index_groups must all be the same size for TPU lowering")
   named_axes, positional_axes = axes_partition = [], []
   for axis in axes:
     axes_partition[isinstance(axis, int)].append(axis)

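A quick check of the constraint enforced above: for TPU lowering, all axis_index_groups must have equal size (the group values here are hypothetical):

    # Equal-size groups satisfy the TPU-only check in _allreduce_lowering above;
    # ragged groups would trigger the ValueError.
    axis_index_groups = [[0, 1], [2, 3]]   # all groups have size 2: accepted
    len_0 = len(axis_index_groups[0])
    assert not any(len(g) != len_0 for g in axis_index_groups)

    ragged = [[0, 1], [2]]                 # sizes 2 and 1: rejected for TPU
    assert any(len(g) != len(ragged[0]) for g in ragged)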
@@ -1175,7 +1180,8 @@ def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, a
   raise AssertionError("Unexpected call to _all_gather_impl")

 def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
-                         axis_index_groups, axis_size, tiled):
+                         axis_index_groups, axis_size, tiled,
+                         platform=None):
   # TODO(jekbradbury): enable for all_gather_dimension > 0
   x_aval, = ctx.avals_in
   out_aval, = ctx.avals_out

@@ -1184,9 +1190,8 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
       axis_context,
       (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
   )
-  if (ctx.module_context.platform == 'tpu' or
-      ctx.module_context.platform in ('cuda', 'rocm')
-      and all_gather_dimension == 0):
+  if (platform == 'tpu' or
+      (platform in ('cuda', 'rocm') and all_gather_dimension == 0)):
     if not tiled:
       new_shape = list(x_aval.shape)
       new_shape.insert(all_gather_dimension, 1)

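Note on the rewritten condition: in Python, `and` binds tighter than `or`, so the old unparenthesized form already grouped as "tpu, or (gpu and dimension 0)"; the new parentheses make that grouping explicit without changing behavior. A small check of the equivalence:

    # `a or b and c` parses as `a or (b and c)`; the truth tables agree.
    for a in (False, True):
      for b in (False, True):
        for c in (False, True):
          assert (a or b and c) == (a or (b and c))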
@@ -1282,6 +1287,10 @@ all_gather_p = core.AxisPrimitive('all_gather')
 all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
 all_gather_p.def_impl(_all_gather_impl)
 mlir.register_lowering(all_gather_p, _all_gather_lowering)
+for p in ("cuda", "rocm", "tpu"):
+  mlir.register_lowering(all_gather_p,
+                         partial(_all_gather_lowering, platform=p),
+                         platform=p)
 ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
 batching.primitive_batchers[all_gather_p] = _all_gather_batcher
 batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective

@@ -1313,63 +1322,68 @@ def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
   return outs


-def _reduce_scatter_lowering(prim, reducer, ctx, x,
-                             *, scatter_dimension, axis_name,
-                             axis_index_groups, axis_size, tiled):
-  if ctx.module_context.platform in ("tpu", "cuda", "rocm"):
-    x_aval, = ctx.avals_in
-    aval_out, = ctx.avals_out
-    scalar_aval = x_aval.update(shape=())
-    replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
-                                     axis_index_groups)
-    scatter_out_shape = list(x_aval.shape)
-    scatter_out_shape[scatter_dimension] //= axis_size
-    axis_context = ctx.module_context.axis_context
-    is_spmd = isinstance(
-        axis_context,
-        (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
-    )
-    if is_spmd:
-      # We want to emit the all-gather with global device IDs and a unique
-      # channel ID, as otherwise it interprets the devices as replicas instead
-      # of partitions - and XLA is configured with only a single replica.
-      channel = ctx.module_context.new_channel()
-      other_args = dict(
-          channel_handle=hlo.ChannelHandle.get(
-              channel, mlir.DEVICE_TO_DEVICE_TYPE),
-          use_global_device_ids=ir.BoolAttr.get(True))
-    else:
-      other_args = {}
-    op = hlo.ReduceScatterOp(
-        mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
-        x,
-        scatter_dimension=mlir.i64_attr(scatter_dimension),
-        replica_groups=_replica_groups_hlo(replica_groups),
-        **other_args)
-    scalar_type = mlir.aval_to_ir_type(scalar_aval)
-    reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
-    with ir.InsertionPoint(reducer_block):
-      lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
-      reducer_ctx = ctx.replace(primitive=None,
-                                avals_in=[scalar_aval] * 2,
-                                avals_out=[scalar_aval])
-      out_nodes = lower_reducer(
-          reducer_ctx, *([a] for a in reducer_block.arguments))
-      hlo.ReturnOp(util.flatten(out_nodes))
-
-    if tiled:
-      return op.results
-    else:
-      return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
-  else:
-    return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
-        ctx, x,
-        reducer=reducer,
-        scatter_dimension=scatter_dimension,
-        axis_name=axis_name,
-        axis_index_groups=axis_index_groups,
-        axis_size=axis_size,
-        tiled=tiled)
+def _reduce_scatter_lowering(
+    prim, reducer, ctx, x,
+    *, scatter_dimension, axis_name,
+    axis_index_groups, axis_size, tiled):
+  x_aval, = ctx.avals_in
+  aval_out, = ctx.avals_out
+  scalar_aval = x_aval.update(shape=())
+  replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
+                                   axis_index_groups)
+  scatter_out_shape = list(x_aval.shape)
+  scatter_out_shape[scatter_dimension] //= axis_size
+  axis_context = ctx.module_context.axis_context
+  is_spmd = isinstance(
+      axis_context,
+      (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
+  )
+  if is_spmd:
+    # We want to emit the all-gather with global device IDs and a unique
+    # channel ID, as otherwise it interprets the devices as replicas instead
+    # of partitions - and XLA is configured with only a single replica.
+    channel = ctx.module_context.new_channel()
+    other_args = dict(
+        channel_handle=hlo.ChannelHandle.get(
+            channel, mlir.DEVICE_TO_DEVICE_TYPE),
+        use_global_device_ids=ir.BoolAttr.get(True))
+  else:
+    other_args = {}
+  op = hlo.ReduceScatterOp(
+      mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
+      x,
+      scatter_dimension=mlir.i64_attr(scatter_dimension),
+      replica_groups=_replica_groups_hlo(replica_groups),
+      **other_args)
+  scalar_type = mlir.aval_to_ir_type(scalar_aval)
+  reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
+  with ir.InsertionPoint(reducer_block):
+    lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
+    reducer_ctx = ctx.replace(primitive=None,
+                              avals_in=[scalar_aval] * 2,
+                              avals_out=[scalar_aval])
+    out_nodes = lower_reducer(
+        reducer_ctx, *([a] for a in reducer_block.arguments))
+    hlo.ReturnOp(util.flatten(out_nodes))
+
+  if tiled:
+    return op.results
+  else:
+    return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
+
+
+def _reduce_scatter_lowering_via_reducer(
+    prim, reducer, ctx, x,
+    *, scatter_dimension, axis_name,
+    axis_index_groups, axis_size, tiled):
+  return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
+      ctx, x,
+      reducer=reducer,
+      scatter_dimension=scatter_dimension,
+      axis_name=axis_name,
+      axis_index_groups=axis_index_groups,
+      axis_size=axis_size,
+      tiled=tiled)


 def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,

@@ -1449,9 +1463,17 @@ reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
 ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
 batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
 batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective

 mlir.register_lowering(
     reduce_scatter_p,
-    partial(_reduce_scatter_lowering, lax.add_p, psum))
+    partial(_reduce_scatter_lowering_via_reducer, lax.add_p, psum))
+reduce_scatter_lowering_for_psum = partial(_reduce_scatter_lowering,
+                                           lax.add_p, psum)
+for p in ("tpu", "cuda", "rocm"):
+  mlir.register_lowering(
+      reduce_scatter_p, reduce_scatter_lowering_for_psum,
+      platform=p)

 core.axis_substitution_rules[reduce_scatter_p] = \
     partial(_subst_all_names_in_param, 'axis_name')

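The registration hunk above keeps the portable reducer-based lowering as the default rule and installs the direct hlo.ReduceScatterOp lowering only for "tpu", "cuda", and "rocm". As a worked example of the output-shape computation in _reduce_scatter_lowering (values hypothetical): scattering an (8, 4) array over 4 participants along dimension 0 yields shape (2, 4).

    # Pure-arithmetic sketch of the scatter_out_shape computation above.
    x_shape = (8, 4)
    axis_size = 4           # number of participants in the scatter
    scatter_dimension = 0
    scatter_out_shape = list(x_shape)
    scatter_out_shape[scatter_dimension] //= axis_size
    assert scatter_out_shape == [2, 4]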
@@ -13,12 +13,18 @@
 # limitations under the License.
 """Tests for multi-platform and cross-platform JAX export."""

+import math
 import re
-from typing import Literal
+from typing import Callable, Sequence

 from absl import logging
 from absl.testing import absltest

 import numpy as np

 import jax
 from jax import lax
+from jax._src import pjit
 from jax._src import test_util as jtu
 from jax.experimental.export import export
 # TODO(necula): Move the primitive harness out of jax2tf so that we can move

@@ -46,7 +52,6 @@ _skip_cuda_lowering_unless_have_gpus = make_disjunction_regexp(
     "random_",
 )

-
 class PrimitiveTest(jtu.JaxTestCase):

   @classmethod

@@ -88,8 +93,21 @@ class PrimitiveTest(jtu.JaxTestCase):
     for l in harness.jax_unimplemented:
       if l.filter(dtype=harness.dtype):
         unimplemented_platforms = unimplemented_platforms.union(l.devices)
+    if (_skip_cuda_lowering_unless_have_gpus.search(harness.fullname)
+        and all(d.platform != "gpu" for d in self.devices)):
+      unimplemented_platforms.add("gpu")
+
     logging.info("Harness is not implemented on %s", unimplemented_platforms)

+    self.export_and_compare_to_native(
+        func_jax, *args,
+        unimplemented_platforms=unimplemented_platforms)
+
+  def export_and_compare_to_native(
+      self, func_jax: Callable,
+      *args: jax.Array,
+      unimplemented_platforms: set[str] = set(),
+      skip_run_on_platforms: set[str] = set()):
     devices = [
         d
         for d in self.__class__.devices

@@ -99,14 +117,9 @@ class PrimitiveTest(jtu.JaxTestCase):
     # lowering_platforms uses "cuda" instead of "gpu"
     lowering_platforms: list[str] = [
         p if p != "gpu" else "cuda"
-        for p in {"cpu", "gpu", "tpu"} - unimplemented_platforms
+        for p in ("cpu", "gpu", "tpu")
+        if p not in unimplemented_platforms
     ]
-    if (
-        "cuda" in lowering_platforms
-        and _skip_cuda_lowering_unless_have_gpus.search(harness.fullname)
-        and all(d.platform != "gpu" for d in devices)
-    ):
-      lowering_platforms.remove("cuda")

     if len(lowering_platforms) <= 1:
       self.skipTest(

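One side effect of this hunk: the old comprehension iterated a set difference, whose order is unspecified, while the new tuple comprehension yields a deterministic platform order. A small self-contained check that both select the same platforms:

    unimplemented_platforms = {"tpu"}
    old_style = {p if p != "gpu" else "cuda"
                 for p in {"cpu", "gpu", "tpu"} - unimplemented_platforms}
    new_style = [p if p != "gpu" else "cuda"
                 for p in ("cpu", "gpu", "tpu")
                 if p not in unimplemented_platforms]
    assert old_style == set(new_style)   # same members
    assert new_style == ["cpu", "cuda"]  # but now in a fixed order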
@@ -117,6 +130,9 @@ class PrimitiveTest(jtu.JaxTestCase):
     exp = export.export(func_jax, lowering_platforms=lowering_platforms)(*args)

     for device in devices:
+      if device.platform in skip_run_on_platforms:
+        logging.info("Skipping running on %s", device)
+        continue
       device_args = jax.tree_util.tree_map(
           lambda x: jax.device_put(x, device), args
       )

@@ -127,6 +143,32 @@ class PrimitiveTest(jtu.JaxTestCase):
       self.assertAllClose(native_res, exported_res)
     # TODO(necula): Check HLO equivalence for the ultimate test.

+  def test_psum_scatter(self):
+    f = jax.jit(jax.pmap(lambda x: lax.psum_scatter(x, 'i'),
+                         axis_name='i',
+                         devices=jax.devices()[:1]))
+
+    shape = (1, 1, 8)
+    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
+    self.export_and_compare_to_native(f, x)
+
+  # The lowering rule for all_gather has special cases for bool.
+  @jtu.parameterized_filterable(
+      kwargs=[
+          dict(dtype=dtype)
+          for dtype in [np.bool_, np.float32]],
+  )
+  def test_all_gather(self, *, dtype):
+    f = jax.jit(jax.pmap(lambda x: lax.all_gather(x, 'i'),
+                         axis_name='i',
+                         devices=jax.devices()[:1]))
+
+    shape = (1, 4)
+    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
+    if dtype == np.bool_:
+      x = (x % 2).astype(np.bool_)
+    self.export_and_compare_to_native(f, x)
+

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())

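For context, a minimal standalone use of the multi-platform export entry point these tests exercise. The function and input here are illustrative; the export.export(..., lowering_platforms=...) call is exactly the one used in export_and_compare_to_native above:

    import numpy as np
    import jax
    from jax.experimental.export import export

    f = jax.jit(jax.numpy.sin)  # any jitted function, as in the tests

    # Lower once for several platforms; the exported artifact carries the
    # per-platform lowerings selected by the rules patched in this commit.
    exp = export.export(f, lowering_platforms=("cpu", "tpu"))(np.float32(1.0))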