Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
Always lower reduce_scatter_p as an HLO ReduceScatter.
We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change.

PiperOrigin-RevId: 586311031
This commit is contained in:
parent 86d9398078
commit 458a8962be
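The decomposition mentioned in the message (ReduceScatter as AllReduce + DynamicSlice) is also what the JAX-level fallback removed below computed. As a minimal runnable sketch of that equivalence through the user-facing jax.lax.psum_scatter (which binds reduce_scatter_p); the pmap setup and the device-count handling here are illustrative assumptions, not part of this change:

# Sketch (not part of the commit): reduce-scatter expressed as an all-reduce
# followed by a per-participant slice, checked against jax.lax.psum_scatter.
# Runs on whatever local devices exist; e.g.
# XLA_FLAGS=--xla_force_host_platform_device_count=4 fakes 4 CPU devices.
import jax
import jax.numpy as jnp

n = jax.device_count()
x = jnp.arange(float(n * n * 3)).reshape(n, n, 3)  # one (n, 3) block per device

# psum_scatter is the user-facing op that binds reduce_scatter_p.
scattered = jax.pmap(
    lambda v: jax.lax.psum_scatter(v, "i", scatter_dimension=0, tiled=False),
    axis_name="i")(x)

# Reference: all-reduce (sum over the mapped axis), then device i keeps row i.
reduced = x.sum(axis=0)
assert jnp.allclose(scattered, reduced)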
@@ -1276,33 +1276,8 @@ batching.axis_primitive_batchers[all_gather_p] = _all_gather_batched_collective
 core.axis_substitution_rules[all_gather_p] = partial(_subst_all_names_in_param, 'axis_name')
 
 
-def _reduce_scatter_via_reducer(x, *, reducer, scatter_dimension, axis_name,
-                                axis_index_groups, axis_size, tiled):
-  index = _index_in_group(axis_name, axis_index_groups)
-  scatter_dim_input_size = x.shape[scatter_dimension]
-  if tiled and scatter_dim_input_size % axis_size != 0:
-    raise ValueError(f"tiled reduce_scatter operand scatter dimension size "
-                     f"{scatter_dim_input_size} must be divisible by "
-                     f"shard count {axis_size}")
-  elif not tiled and scatter_dim_input_size != axis_size:
-    raise ValueError(f"reduce_scatter operand scatter dimension size "
-                     f"{scatter_dim_input_size} must match shard count"
-                     f"{axis_size}")
-  scatter_dim_output_size = scatter_dim_input_size // axis_size
-
-  outs = reducer(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
-  outs = slicing.dynamic_slice_in_dim(
-      outs,
-      start_index=index * scatter_dim_output_size,
-      slice_size=scatter_dim_output_size,
-      axis=scatter_dimension)
-  if not tiled:
-    outs = lax.squeeze(outs, [scatter_dimension])
-  return outs
-
-
 def _reduce_scatter_lowering(
-    prim, reducer, ctx, x,
+    prim, ctx, x,
     *, scatter_dimension, axis_name,
     axis_index_groups, axis_size, tiled):
   x_aval, = ctx.avals_in
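For reference, a NumPy sketch of what the deleted _reduce_scatter_via_reducer computed for one participant, including its tiled vs. untiled shape rules; the function name and sample data below are illustrative only, not code from this commit:

# Illustrative reference (not JAX code): reduce-scatter for participant
# `index` out of `axis_size` shards, mirroring the deleted fallback.
import numpy as np

def reduce_scatter_reference(shards, scatter_dimension, index, tiled):
  axis_size = len(shards)
  reduced = np.sum(shards, axis=0)                 # the "reducer" step (psum)
  size = reduced.shape[scatter_dimension]
  if tiled:
    # Tiled: the scatter dimension is split into axis_size equal chunks.
    assert size % axis_size == 0
    out = size // axis_size
    return np.take(reduced, range(index * out, (index + 1) * out),
                   axis=scatter_dimension)
  # Untiled: the scatter dimension must equal axis_size and is squeezed away.
  assert size == axis_size
  return np.take(reduced, index, axis=scatter_dimension)

shards = [np.arange(8.0).reshape(4, 2) + i for i in range(4)]
print(reduce_scatter_reference(shards, 0, index=1, tiled=False).shape)  # (2,)
print(reduce_scatter_reference(shards, 0, index=1, tiled=True).shape)   # (1, 2)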
@@ -1350,19 +1325,6 @@ def _reduce_scatter_lowering(
   else:
     return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]
 
-def _reduce_scatter_lowering_via_reducer(
-    prim, reducer, ctx, x,
-    *, scatter_dimension, axis_name,
-    axis_index_groups, axis_size, tiled):
-
-  return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
-      ctx, x,
-      reducer=reducer,
-      scatter_dimension=scatter_dimension,
-      axis_name=axis_name,
-      axis_index_groups=axis_index_groups,
-      axis_size=axis_size,
-      tiled=tiled)
 
 
 def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
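The deleted rule used mlir.lower_fun to lower the primitive by tracing a Python implementation built from other JAX ops. Below is a minimal sketch of that pattern with a hypothetical double_demo primitive, assuming the jax.core / jax.interpreters.mlir extension API of this era; it is not part of the commit:

# Hypothetical example of the mlir.lower_fun pattern (not from this commit).
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

double_p = core.Primitive("double_demo")
double_p.def_impl(lambda x: 2 * x)
double_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def _double_via_add(x):
  # Implemented in terms of ops JAX already knows how to lower.
  return jnp.add(x, x)

# lower_fun turns the Python function into a lowering rule (ctx, *args) -> results.
mlir.register_lowering(double_p,
                       mlir.lower_fun(_double_via_add, multiple_results=False))

print(jax.jit(double_p.bind)(jnp.arange(4.0)))  # [0. 2. 4. 6.]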
@@ -1443,15 +1405,8 @@ ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
 batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
 batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
 
-mlir.register_lowering(
-    reduce_scatter_p,
-    partial(_reduce_scatter_lowering_via_reducer, lax.add_p, psum))
-reduce_scatter_lowering_for_psum = partial(_reduce_scatter_lowering,
-                                           lax.add_p, psum)
-for p in ("tpu", "cuda", "rocm"):
-  mlir.register_lowering(
-      reduce_scatter_p, reduce_scatter_lowering_for_psum,
-      platform=p)
+mlir.register_lowering(reduce_scatter_p,
+                       partial(_reduce_scatter_lowering, lax.add_p))
 
 core.axis_substitution_rules[reduce_scatter_p] = \
     partial(_subst_all_names_in_param, 'axis_name')
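With the per-platform loop gone, the single registration above applies on every backend, CPU included. One hedged way to see the effect locally is to dump the lowered module for a psum_scatter program; the AOT .lower()/.as_text() calls and the exact collective op name in the output are assumptions about the local toolchain, not something this diff guarantees:

# Sketch: inspect what a psum_scatter program lowers to on the local backend.
import jax
import jax.numpy as jnp

n = jax.device_count()
fn = jax.pmap(
    lambda v: jax.lax.psum_scatter(v, "i", scatter_dimension=0, tiled=True),
    axis_name="i")
x = jnp.zeros((n, 4 * n))

# The emitted module should contain the reduce-scatter collective regardless
# of platform, now that one lowering rule covers them all.
print(fn.lower(x).as_text())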