mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Cleanup lowering rule for hlo_unshard, to remove platform dependence.
PiperOrigin-RevId: 572997889
This commit is contained in:
parent
294fe80650
commit
ab161bbd40
@ -1324,18 +1324,11 @@ def _axis_groups(mesh_spec, mesh_axes):
|
||||
|
||||
|
||||
# TODO(b/110096942): more efficient gather
|
||||
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
|
||||
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
|
||||
if aval is core.abstract_token:
|
||||
return xs
|
||||
elif isinstance(aval, core.ShapedArray):
|
||||
x, = xs
|
||||
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
|
||||
convert_bool = (np.issubdtype(aval.dtype, np.bool_)
|
||||
and platform in ('cpu', 'gpu'))
|
||||
if convert_bool:
|
||||
aval = aval.update(dtype=np.dtype(np.float32))
|
||||
x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
|
||||
|
||||
dims = list(aval.shape)
|
||||
padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
|
||||
padded = mlir.full_like_aval(ctx, 0, padded_aval)
|
||||
@ -1353,17 +1346,8 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl
|
||||
perm.insert(out_axis, 0)
|
||||
transposed_dims = list(dims)
|
||||
transposed_dims.insert(out_axis, axis_env.sizes[-1])
|
||||
aval = aval.update(shape=transposed_dims)
|
||||
out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
|
||||
|
||||
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
|
||||
if convert_bool:
|
||||
float_zero = mlir.full_like_aval(ctx, 0, padded_aval)
|
||||
out = hlo.CompareOp(
|
||||
out,
|
||||
float_zero,
|
||||
hlo.ComparisonDirectionAttr.get("NE"),
|
||||
compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result
|
||||
return out
|
||||
else:
|
||||
raise TypeError(aval)
|
||||
@ -1402,8 +1386,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
|
||||
*in_nodes_sharded,
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard,
|
||||
platform=ctx.module_context.platform)
|
||||
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard)
|
||||
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
|
||||
return outs
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user