mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic] Add min/max lowering rules for Mosaic.
PiperOrigin-RevId: 568929392
This commit is contained in:
parent
1885c4933c
commit
87af945cbe
@ -927,6 +927,22 @@ lowering_rules[lax.max_p] = _max_lowering_rule
|
||||
skip_mlir_conversions.add(lax.max_p)
|
||||
|
||||
|
||||
def _min_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
||||
(aval_out,) = ctx.avals_out
|
||||
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
|
||||
return arith.MinSIOp(x, y).result
|
||||
elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
|
||||
return arith.MinUIOp(x, y).result
|
||||
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
|
||||
return arith.MinimumFOp(x, y).result
|
||||
raise NotImplementedError(aval_out.dtype)
|
||||
|
||||
|
||||
lowering_rules[lax.min_p] = _min_lowering_rule
|
||||
skip_mlir_conversions.add(lax.min_p)
|
||||
|
||||
|
||||
def _sub_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
|
||||
(aval_out,) = ctx.avals_out
|
||||
|
@ -1646,9 +1646,15 @@ _register_rule("arith.remsi")(
|
||||
_register_rule("arith.maximumf")(
|
||||
functools.partial(_elementwise_op_rule, arith.MaximumFOp)
|
||||
)
|
||||
_register_rule("arith.maxsi")(
|
||||
functools.partial(_elementwise_op_rule, arith.MaxSIOp)
|
||||
)
|
||||
_register_rule("arith.minimumf")(
|
||||
functools.partial(_elementwise_op_rule, arith.MinimumFOp)
|
||||
)
|
||||
_register_rule("arith.minsi")(
|
||||
functools.partial(_elementwise_op_rule, arith.MinSIOp)
|
||||
)
|
||||
_register_rule("arith.select")(
|
||||
functools.partial(_elementwise_op_rule, arith.SelectOp))
|
||||
_register_rule("arith.index_cast")(
|
||||
|
Loading…
x
Reference in New Issue
Block a user