Remove explicit_type argument from _nary_lower_hlo.

PiperOrigin-RevId: 683395436
This commit is contained in:
jax authors 2024-10-07 18:00:56 -07:00
parent a9e9f97f00
commit 0854dc24e8

View File

@ -2186,12 +2186,8 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
def _nary_lower_hlo(op: Callable, ctx,
*args: ir.Value,
explicit_type=False, **params) -> Sequence[ir.Value]:
*args: ir.Value, **params) -> Sequence[ir.Value]:
"""Lowers an elementwise operator to its MLIR equivalent.
Args:
explicit_type: does the MLIR op require its output type to be provided?
"""
del params
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
@ -2199,10 +2195,7 @@ def _nary_lower_hlo(op: Callable, ctx,
if config.sharding_in_types.value:
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
if explicit_type:
out = op(mlir.aval_to_ir_type(aval_out), *args)
else:
out = op(*args)
out = op(*args)
if config.sharding_in_types.value:
if config.use_shardy_partitioner.value:
out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim)