[Mosaic] Add min/max lowering rules for Mosaic.

PiperOrigin-RevId: 568929392
This commit is contained in:
Emily Fertig 2023-09-27 12:34:31 -07:00 committed by jax authors
parent 1885c4933c
commit 87af945cbe
2 changed files with 22 additions and 0 deletions

View File

@ -927,6 +927,22 @@ lowering_rules[lax.max_p] = _max_lowering_rule
skip_mlir_conversions.add(lax.max_p)
def _min_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
(aval_out,) = ctx.avals_out
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
return arith.MinSIOp(x, y).result
elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
return arith.MinUIOp(x, y).result
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
return arith.MinimumFOp(x, y).result
raise NotImplementedError(aval_out.dtype)
lowering_rules[lax.min_p] = _min_lowering_rule
skip_mlir_conversions.add(lax.min_p)
def _sub_lowering_rule(ctx: LoweringRuleContext, x, y):
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
(aval_out,) = ctx.avals_out

View File

@ -1646,9 +1646,15 @@ _register_rule("arith.remsi")(
_register_rule("arith.maximumf")(
functools.partial(_elementwise_op_rule, arith.MaximumFOp)
)
_register_rule("arith.maxsi")(
functools.partial(_elementwise_op_rule, arith.MaxSIOp)
)
_register_rule("arith.minimumf")(
functools.partial(_elementwise_op_rule, arith.MinimumFOp)
)
_register_rule("arith.minsi")(
functools.partial(_elementwise_op_rule, arith.MinSIOp)
)
_register_rule("arith.select")(
functools.partial(_elementwise_op_rule, arith.SelectOp))
_register_rule("arith.index_cast")(