Always lower reduce_scatter_p as an HLO ReduceScatter.

We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
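
For readers unfamiliar with the decomposition mentioned above, here is a minimal sketch (not JAX's implementation, and not part of this change) of reduce-scatter expressed as an all-reduce followed by a dynamic slice, using only public JAX APIs; the helper name and the axis name 'i' are illustrative:

import jax
from jax import lax

def reduce_scatter_via_psum(x, axis_name):
  # AllReduce: every participant gets the full elementwise sum.
  summed = lax.psum(x, axis_name)
  # DynamicSlice: each participant keeps only its own shard of the sum.
  axis_size = lax.psum(1, axis_name)      # number of participants
  chunk = x.shape[0] // axis_size
  start = lax.axis_index(axis_name) * chunk
  return lax.dynamic_slice_in_dim(summed, start, chunk, axis=0)

# Under pmap this should agree with the native primitive:
#   jax.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(xs)
#   jax.pmap(lambda x: reduce_scatter_via_psum(x, 'i'), axis_name='i')(xs)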
Peter Hawkins, 2023-11-29 05:37:19 -08:00 (committed by jax authors)
parent 86d9398078
commit 458a8962be


@@ -1276,33 +1276,8 @@ batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
 core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
 
 
-def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
-                                axis_index_groups, axis_size, tiled):
-  index = _index_in_group(axis_name, axis_index_groups)
-  scatter_dim_input_size = x.shape[scatter_dimension]
-  if tiled and scatter_dim_input_size % axis_size != 0:
-    raise ValueError(f"tiled reduce_scatter operand scatter dimension size "
-                     f"{scatter_dim_input_size} must be divisible by "
-                     f"shard count {axis_size}")
-  elif not tiled and scatter_dim_input_size != axis_size:
-    raise ValueError(f"reduce_scatter operand scatter dimension size "
-                     f"{scatter_dim_input_size} must match shard count"
-                     f"{axis_size}")
-  scatter_dim_output_size = scatter_dim_input_size // axis_size
-
-  outs = reducer(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
-  outs = slicing.dynamic_slice_in_dim(
-      outs,
-      start_index=index * scatter_dim_output_size,
-      slice_size=scatter_dim_output_size,
-      axis=scatter_dimension)
-  if not tiled:
-    outs = lax.squeeze(outs, [scatter_dimension])
-  return outs
-
-
 def _reduce_scatter_lowering(
-    prim, reducer, ctx, x,
+    prim, ctx, x,
     *, scatter_dimension, axis_name,
     axis_index_groups, axis_size, tiled):
   x_aval, = ctx.avals_in
@@ -1350,19 +1325,6 @@ def _reduce_scatter_lowering(
   else:
     return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]
 
 
-def _reduce_scatter_lowering_via_reducer(
-    prim, reducer, ctx, x,
-    *, scatter_dimension, axis_name,
-    axis_index_groups, axis_size, tiled):
-  return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
-      ctx, x,
-      reducer=reducer,
-      scatter_dimension=scatter_dimension,
-      axis_name=axis_name,
-      axis_index_groups=axis_index_groups,
-      axis_size=axis_size,
-      tiled=tiled)
-
 def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
@@ -1443,15 +1405,8 @@ ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
 batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
 batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
 
-mlir.register_lowering(
-    reduce_scatter_p,
-    partial(_reduce_scatter_lowering_via_reducer, lax.add_p, psum))
-reduce_scatter_lowering_for_psum = partial(_reduce_scatter_lowering,
-                                           lax.add_p, psum)
-for p in ("tpu", "cuda", "rocm"):
-  mlir.register_lowering(
-      reduce_scatter_p, reduce_scatter_lowering_for_psum,
-      platform=p)
+mlir.register_lowering(reduce_scatter_p,
+                       partial(_reduce_scatter_lowering, lax.add_p))
 
 core.axis_substitution_rules[reduce_scatter_p] = \
     partial(_subst_all_names_in_param, 'axis_name')
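
With the single registration above replacing the platform-gated one, one way to check the result (a sketch, assuming a runtime with 8 local devices; nothing below is part of this change) is to inspect the lowered StableHLO and look for a reduce_scatter op instead of an all_reduce + dynamic_slice expansion:

import jax
from jax import lax
import jax.numpy as jnp

f = jax.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')
xs = jnp.arange(8 * 8, dtype=jnp.float32).reshape(8, 8)
# Expect "stablehlo.reduce_scatter" in the module text on CPU as well as
# TPU/GPU, now that all platforms share the same lowering rule.
print(f.lower(xs).as_text())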