mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[pallas:gpu] Lower more complex primitives using JAX functions in terms of more basic primitives.
PiperOrigin-RevId: 575883386
This commit is contained in:
parent
8fa287e4e7
commit
b61af5a104
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import functools
|
||||
import operator
|
||||
from typing import Any, Dict, Sequence, Tuple
|
||||
from typing import Any, Callable, Dict, Sequence, Tuple
|
||||
import zlib
|
||||
|
||||
import jax
|
||||
@ -432,23 +432,39 @@ for primitive, fn in _TRITON_FN_MAPPING.items():
|
||||
triton_lowering_rules[primitive] = rule
|
||||
|
||||
|
||||
def _clamp_lowering_rule(ctx: TritonLoweringRuleContext, min, operand, max):
|
||||
operand = tl.math.max(operand, min, _builder=ctx.builder)
|
||||
return tl.math.min(operand, max, _builder=ctx.builder)
|
||||
def _integer_pow(a, *, y):
|
||||
if y == 2:
|
||||
return a * a
|
||||
if y == 3:
|
||||
return a * a * a
|
||||
if y == -2:
|
||||
return 1.0 / (a * a)
|
||||
return jax.lax.pow(a, y)
|
||||
|
||||
|
||||
triton_lowering_rules[lax.clamp_p] = _clamp_lowering_rule
|
||||
def lower_fun(
|
||||
fun: Callable[..., Any], *, multiple_results: bool
|
||||
) -> Callable[..., Any]:
|
||||
fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
||||
|
||||
def f_lowered(ctx: TritonLoweringRuleContext, *args, **params):
|
||||
wrapped_fun = lu.wrap_init(fn, params)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
|
||||
out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr)
|
||||
return out if multiple_results else out[0]
|
||||
|
||||
return f_lowered
|
||||
|
||||
|
||||
def _logistic_lowering_rule(ctx: TritonLoweringRuleContext, a):
|
||||
one_ = tl.core._to_tensor(1.0, ctx.builder)
|
||||
x = tl.exp(a.__neg__(_builder=ctx.builder), _builder=ctx.builder)
|
||||
x = x.__add__(one_, _builder=ctx.builder)
|
||||
x = one_.__truediv__(x, _builder=ctx.builder)
|
||||
return x
|
||||
_JAX_FN_MAPPING = {
|
||||
lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max),
|
||||
lax.integer_pow_p: _integer_pow,
|
||||
lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)),
|
||||
}
|
||||
|
||||
|
||||
triton_lowering_rules[lax.logistic_p] = _logistic_lowering_rule
|
||||
for primitive, fn in _JAX_FN_MAPPING.items():
|
||||
triton_lowering_rules[primitive] = lower_fun(fn, multiple_results=False)
|
||||
|
||||
|
||||
def _div_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
|
||||
@ -471,21 +487,6 @@ def _iota_lowering_rule(
|
||||
triton_lowering_rules[lax.iota_p] = _iota_lowering_rule
|
||||
|
||||
|
||||
def _integer_pow_lowering_rule(ctx: TritonLoweringRuleContext, a, *, y):
|
||||
if y == 2:
|
||||
return a.__mul__(a, _builder=ctx.builder)
|
||||
if y == 3:
|
||||
return a.__mul__(a.__mul__(a, _builder=ctx.builder), _builder=ctx.builder)
|
||||
if y == -2:
|
||||
one_ = tl.core._to_tensor(1.0, ctx.builder)
|
||||
a_sq = a.__mul__(a, _builder=ctx.builder)
|
||||
return one_.__truediv__(a_sq, _builder=ctx.builder)
|
||||
return tl.math.pow(a, y, _builder=ctx.builder)
|
||||
|
||||
|
||||
triton_lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule
|
||||
|
||||
|
||||
def _convert_element_type_lowering_rule(
|
||||
ctx: TritonLoweringRuleContext, a, *, new_dtype, weak_type
|
||||
):
|
||||
|
Loading…
x
Reference in New Issue
Block a user