[Pallas] Add Mosaic lowering rule for fpowi.

PiperOrigin-RevId: 565800521
This commit is contained in:
Emily Fertig 2023-09-15 16:00:19 -07:00 committed by jax authors
parent 65de2cf907
commit e78d8a321e

View File

@ -30,6 +30,7 @@ from jax._src import source_info_util
from jax._src import state
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
@ -1036,6 +1037,14 @@ lowering_rules[lax.pow_p] = _pow_lowering_rule
skip_mlir_conversions.add(lax.pow_p)
def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
return lower_fun(lax_internal._integer_pow, multiple_results=False)(
ctx, x, y=y)
lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
# here.