[pallas:gpu] Lower more complex primitives using JAX functions in terms of more basic primitives.

PiperOrigin-RevId: 575883386
This commit is contained in:
Chris Jones 2023-10-23 11:45:02 -07:00 committed by jax authors
parent 8fa287e4e7
commit b61af5a104

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import dataclasses
import functools
import operator
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Callable, Dict, Sequence, Tuple
import zlib
import jax
@ -432,23 +432,39 @@ for primitive, fn in _TRITON_FN_MAPPING.items():
triton_lowering_rules[primitive] = rule
def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
operand = tl.math.max(operand, min, _builder=ctx.builder)
return tl.math.min(operand, max, _builder=ctx.builder)
def _integer_pow(a, *, y):
if y == 2:
return a * a
if y == 3:
return a * a * a
if y == -2:
return 1.0 / (a * a)
return jax.lax.pow(a, y)
triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
def lower_fun(
fun: Callable[..., Any], *, multiple_results: bool
) -> Callable[..., Any]:
fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
def f_lowered(ctx: TritonLoweringRuleContext, *args, **params):
wrapped_fun = lu.wrap_init(fn, params)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
return out if multiple_results else out[0]
return f_lowered
def _logistic_lowering_rule(ctx: TritonLoweringRuleContext, a):
one_ = tl.core._to_tensor(1.0, ctx.builder)
x = tl.exp(a.__neg__(_builder=ctx.builder), _builder=ctx.builder)
x = x.__add__(one_, _builder=ctx.builder)
x = one_.__truediv__(x, _builder=ctx.builder)
return x
_JAX_FN_MAPPING = {
lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
lax.integer_pow_p: _integer_pow,
lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
}
triton_lowering_rules[lax.logistic_p] = _logistic_lowering_rule
for primitive, fn in _JAX_FN_MAPPING.items():
triton_lowering_rules[primitive] = lower_fun(fn, multiple_results=False)
def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
@ -471,21 +487,6 @@ def _iota_lowering_rule(
triton_lowering_rules[lax.iota_p] = _iota_lowering_rule
def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
if y == 2:
return a.__mul__(a, _builder=ctx.builder)
if y == 3:
return a.__mul__(a.__mul__(a, _builder=ctx.builder), _builder=ctx.builder)
if y == -2:
one_ = tl.core._to_tensor(1.0, ctx.builder)
a_sq = a.__mul__(a, _builder=ctx.builder)
return one_.__truediv__(a_sq, _builder=ctx.builder)
return tl.math.pow(a, y, _builder=ctx.builder)
triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
def _convert_element_type_lowering_rule(
ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
):