From ab161bbd404f2f587b49cde880f1550381d9964d Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 12 Oct 2023 13:32:47 -0700 Subject: [PATCH] Cleanup lowering rule for hlo_unshard, to remove platform dependence. PiperOrigin-RevId: 572997889 --- jax/_src/interpreters/pxla.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 0b61aaed5..1a6071829 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1324,18 +1324,11 @@ def _axis_groups(mesh_spec, mesh_axes): # TODO(b/110096942): more efficient gather -def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform): +def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs): if aval is core.abstract_token: return xs elif isinstance(aval, core.ShapedArray): x, = xs - # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU - convert_bool = (np.issubdtype(aval.dtype, np.bool_) - and platform in ('cpu', 'gpu')) - if convert_bool: - aval = aval.update(dtype=np.dtype(np.float32)) - x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result - dims = list(aval.shape) padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims) padded = mlir.full_like_aval(ctx, 0, padded_aval) @@ -1353,17 +1346,8 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl perm.insert(out_axis, 0) transposed_dims = list(dims) transposed_dims.insert(out_axis, axis_env.sizes[-1]) - aval = aval.update(shape=transposed_dims) out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result - # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU - if convert_bool: - float_zero = mlir.full_like_aval(ctx, 0, padded_aval) - out = hlo.CompareOp( - out, - float_zero, - hlo.ComparisonDirectionAttr.get("NE"), - compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result return out else: raise TypeError(aval) @@ -1402,8 +1386,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, *in_nodes_sharded, dim_var_values=ctx.dim_var_values) out_avals = [v.aval for v in call_jaxpr.outvars] - outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard, - platform=ctx.module_context.platform) + outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard) for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)] return outs