Replace hand-written Triton builder rules with JAX-level lowerings.

The clamp, logistic and integer_pow rules are re-expressed as plain JAX
functions collected in _JAX_FN_MAPPING. lower_fun wraps each function,
traces it to a jaxpr with pe.trace_to_jaxpr_dynamic, and lowers the result
through _closed_call_lowering_rule.

diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index b8e01bf55..d0dfd0eb7 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 import dataclasses
 import functools
 import operator
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Callable, Dict, Sequence, Tuple
 import zlib
 
 import jax
@@ -432,23 +432,39 @@ for primitive, fn in _TRITON_FN_MAPPING.items():
   triton_lowering_rules[primitive] = rule
 
 
-def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
-  operand = tl.math.max(operand, min, _builder=ctx.builder)
-  return tl.math.min(operand, max, _builder=ctx.builder)
+def _integer_pow(a, *, y):
+  if y == 2:
+    return a * a
+  if y == 3:
+    return a * a * a
+  if y == -2:
+    return 1.0 / (a * a)
+  return jax.lax.pow(a, y)
 
 
-triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
+def lower_fun(
+    fun: Callable[..., Any], *, multiple_results: bool
+) -> Callable[..., Any]:
+  fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
+
+  def f_lowered(ctx: TritonLoweringRuleContext, *args, **params):
+    wrapped_fun = lu.wrap_init(fn, params)
+    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+    jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
+    out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
+    return out if multiple_results else out[0]
+
+  return f_lowered
 
 
-def _logistic_lowering_rule(ctx: TritonLoweringRuleContext, a):
-  one_ = tl.core._to_tensor(1.0, ctx.builder)
-  x = tl.exp(a.__neg__(_builder=ctx.builder), _builder=ctx.builder)
-  x = x.__add__(one_, _builder=ctx.builder)
-  x = one_.__truediv__(x, _builder=ctx.builder)
-  return x
+_JAX_FN_MAPPING = {
+    lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
+    lax.integer_pow_p: _integer_pow,
+    lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
+}
 
-
-triton_lowering_rules[lax.logistic_p] = _logistic_lowering_rule
+for primitive, fn in _JAX_FN_MAPPING.items():
+  triton_lowering_rules[primitive] = lower_fun(fn, multiple_results=False)
 
 
 def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
@@ -471,21 +487,6 @@ def _iota_lowering_rule(
 triton_lowering_rules[lax.iota_p] = _iota_lowering_rule
 
 
-def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
-  if y == 2:
-    return a.__mul__(a, _builder=ctx.builder)
-  if y == 3:
-    return a.__mul__(a.__mul__(a, _builder=ctx.builder), _builder=ctx.builder)
-  if y == -2:
-    one_ = tl.core._to_tensor(1.0, ctx.builder)
-    a_sq = a.__mul__(a, _builder=ctx.builder)
-    return one_.__truediv__(a_sq, _builder=ctx.builder)
-  return tl.math.pow(a, y, _builder=ctx.builder)
-
-
-triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
-
-
 def _convert_element_type_lowering_rule(
     ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
 ):