diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 8689bc4ee..05c4e6946 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -940,6 +940,15 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.exp_p] = _exp_lowering_rule +def _exp2_lowering_rule(ctx: LoweringRuleContext, x): + # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior + # here. + return lower_fun(lambda x: jnp.exp(np.log(2) * x), multiple_results=False)( + ctx, x) + + +lowering_rules[lax.exp2_p] = _exp2_lowering_rule + def _logistic_lowering_rule(ctx: LoweringRuleContext, x): neg_x = arith.NegFOp(x).result exp_neg_x = math.ExpOp(neg_x).result