mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove explicit_type argument from _nary_lower_hlo.
PiperOrigin-RevId: 683395436
This commit is contained in:
parent
a9e9f97f00
commit
0854dc24e8
@ -2186,12 +2186,8 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
|
||||
|
||||
|
||||
def _nary_lower_hlo(op: Callable, ctx,
|
||||
*args: ir.Value,
|
||||
explicit_type=False, **params) -> Sequence[ir.Value]:
|
||||
*args: ir.Value, **params) -> Sequence[ir.Value]:
|
||||
"""Lowers an elementwise operator to its MLIR equivalent.
|
||||
|
||||
Args:
|
||||
explicit_type: does the MLIR op require its output type to be provided?
|
||||
"""
|
||||
del params
|
||||
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
|
||||
@ -2199,10 +2195,7 @@ def _nary_lower_hlo(op: Callable, ctx,
|
||||
if config.sharding_in_types.value:
|
||||
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
|
||||
|
||||
if explicit_type:
|
||||
out = op(mlir.aval_to_ir_type(aval_out), *args)
|
||||
else:
|
||||
out = op(*args)
|
||||
out = op(*args)
|
||||
if config.sharding_in_types.value:
|
||||
if config.use_shardy_partitioner.value:
|
||||
out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim)
|
||||
|
Loading…
x
Reference in New Issue
Block a user