From 458a8962be48eac9a45b432bd385c88003fd278f Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 29 Nov 2023 05:37:19 -0800
Subject: [PATCH] Always lower reduce_scatter_p as an HLO ReduceScatter.

We don't need the fallback path for CPU: XLA:CPU already does its own
lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it
a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
---
 jax/_src/lax/parallel.py | 51 +++------------------------------------------------
 1 file changed, 3 insertions(+), 48 deletions(-)

diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index 07bf4946e..d29d8bdb6 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -1276,33 +1276,8 @@ batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
 core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
 
 
-def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
-                                axis_index_groups, axis_size, tiled):
-  index = _index_in_group(axis_name, axis_index_groups)
-  scatter_dim_input_size = x.shape[scatter_dimension]
-  if tiled and scatter_dim_input_size % axis_size != 0:
-    raise ValueError(f"tiled reduce_scatter operand scatter dimension size "
-                     f"{scatter_dim_input_size} must be divisible by "
-                     f"shard count {axis_size}")
-  elif not tiled and scatter_dim_input_size != axis_size:
-    raise ValueError(f"reduce_scatter operand scatter dimension size "
-                     f"{scatter_dim_input_size} must match shard count"
-                     f"{axis_size}")
-  scatter_dim_output_size = scatter_dim_input_size // axis_size
-
-  outs = reducer(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
-  outs = slicing.dynamic_slice_in_dim(
-      outs,
-      start_index=index * scatter_dim_output_size,
-      slice_size=scatter_dim_output_size,
-      axis=scatter_dimension)
-  if not tiled:
-    outs = lax.squeeze(outs, [scatter_dimension])
-  return outs
-
-
 def _reduce_scatter_lowering(
-    prim, reducer, ctx, x,
+    prim, ctx, x,
     *, scatter_dimension, axis_name,
     axis_index_groups, axis_size, tiled):
   x_aval, = ctx.avals_in
@@ -1350,19 +1325,6 @@ def _reduce_scatter_lowering(
   else:
     return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]
 
-def _reduce_scatter_lowering_via_reducer(
-    prim, reducer, ctx, x,
-    *, scatter_dimension, axis_name,
-    axis_index_groups, axis_size, tiled):
-
-  return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
-      ctx, x,
-      reducer=reducer,
-      scatter_dimension=scatter_dimension,
-      axis_name=axis_name,
-      axis_index_groups=axis_index_groups,
-      axis_size=axis_size,
-      tiled=tiled)
 
 
 def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
@@ -1443,15 +1405,8 @@ ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
 batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
 batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
 
-mlir.register_lowering(
-    reduce_scatter_p,
-    partial(_reduce_scatter_lowering_via_reducer, lax.add_p, psum))
-reduce_scatter_lowering_for_psum = partial(_reduce_scatter_lowering,
-                                           lax.add_p, psum)
-for p in ("tpu", "cuda", "rocm"):
-  mlir.register_lowering(
-      reduce_scatter_p, reduce_scatter_lowering_for_psum,
-      platform=p)
+mlir.register_lowering(reduce_scatter_p,
+                       partial(_reduce_scatter_lowering, lax.add_p))
 
 core.axis_substitution_rules[reduce_scatter_p] = \
     partial(_subst_all_names_in_param, 'axis_name')