mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Pallas] Add Mosaic lowering for exp2
PiperOrigin-RevId: 557646023
This commit is contained in:
parent
caee3120fd
commit
785af827be
@ -940,6 +940,15 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
lowering_rules[lax.exp_p] = _exp_lowering_rule
|
||||
|
||||
|
||||
def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
# exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
|
||||
# here.
|
||||
return lower_fun(lambda x: jnp.exp(np.log(2) * x), multiple_results=False)(
|
||||
ctx, x)
|
||||
|
||||
|
||||
lowering_rules[lax.exp2_p] = _exp2_lowering_rule
|
||||
|
||||
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
neg_x = arith.NegFOp(x).result
|
||||
exp_neg_x = math.ExpOp(neg_x).result
|
||||
|
Loading…
x
Reference in New Issue
Block a user