mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Pallas] Add Mosaic lowering rule for fpowi.
PiperOrigin-RevId: 565800521
This commit is contained in:
parent
65de2cf907
commit
e78d8a321e
@ -30,6 +30,7 @@ from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax.control_flow import for_loop
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
@ -1036,6 +1037,14 @@ lowering_rules[lax.pow_p] = _pow_lowering_rule
|
||||
skip_mlir_conversions.add(lax.pow_p)
|
||||
|
||||
|
||||
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
|
||||
return lower_fun(lax_internal._integer_pow, multiple_results=False)(
|
||||
ctx, x, y=y)
|
||||
|
||||
|
||||
lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
|
||||
|
||||
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
|
||||
# here.
|
||||
|
Loading…
x
Reference in New Issue
Block a user