Cleanup lowering rule for hlo_unshard, to remove platform dependence.

PiperOrigin-RevId: 572997889
This commit is contained in:
George Necula 2023-10-12 13:32:47 -07:00 committed by jax authors
parent 294fe80650
commit ab161bbd40

View File

@ -1324,18 +1324,11 @@ def _axis_groups(mesh_spec, mesh_axes):
# TODO(b/110096942): more efficient gather
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
if aval is core.abstract_token:
return xs
elif isinstance(aval, core.ShapedArray):
x, = xs
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
convert_bool = (np.issubdtype(aval.dtype, np.bool_)
and platform in ('cpu', 'gpu'))
if convert_bool:
aval = aval.update(dtype=np.dtype(np.float32))
x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
dims = list(aval.shape)
padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
padded = mlir.full_like_aval(ctx, 0, padded_aval)
@ -1353,17 +1346,8 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
perm.insert(out_axis, 0)
transposed_dims = list(dims)
transposed_dims.insert(out_axis, axis_env.sizes[-1])
aval = aval.update(shape=transposed_dims)
out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
float_zero = mlir.full_like_aval(ctx, 0, padded_aval)
out = hlo.CompareOp(
out,
float_zero,
hlo.ComparisonDirectionAttr.get("NE"),
compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result
return out
else:
raise TypeError(aval)
@ -1402,8 +1386,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
*in_nodes_sharded,
dim_var_values=ctx.dim_var_values)
out_avals = [v.aval for v in call_jaxpr.outvars]
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard,
platform=ctx.module_context.platform)
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard)
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
return outs