[Pallas] Add Mosaic lowering for exp2

PiperOrigin-RevId: 557646023
This commit is contained in:
Sharad Vikram 2023-08-16 17:04:28 -07:00 committed by jax authors
parent caee3120fd
commit 785af827be

View File

@ -940,6 +940,15 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.exp_p] = _exp_lowering_rule
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
# here.
return lower_fun(lambda x: jnp.exp(np.log(2) * x), multiple_results=False)(
ctx, x)
lowering_rules[lax.exp2_p] = _exp2_lowering_rule
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
neg_x = arith.NegFOp(x).result
exp_neg_x = math.ExpOp(neg_x).result