Lower a subset of math primitives directly to Triton IR

Note that all primitives are now lowered to libdevice calls. Previously,
some of them were lowered to the MLIR arith dialect, and some to libdevice
calls, without any apparent reason for doing so.

PiperOrigin-RevId: 601259707
This commit is contained in:
Sergei Lebedev 2024-01-24 15:54:31 -08:00 committed by jax authors
parent cfb6250158
commit f15cad4651
2 changed files with 145 additions and 27 deletions

View File

@ -463,19 +463,19 @@ triton_lowering_rules[lax.cumsum_p] = _cumsum_lowering_rule
_TRITON_FN_MAPPING = {
# Unary ops.
lax.neg_p: tc.semantic.minus,
lax.abs_p: tc.abs,
lax.abs_p: tc.math.abs,
lax.ceil_p: tc.math.ceil,
lax.floor_p: tc.math.floor,
lax.exp_p: tc.exp,
lax.exp_p: tc.math.exp,
lax.exp2_p: tc.math.exp2,
lax.expm1_p: tc.math.expm1,
lax.log_p: tc.log,
lax.log_p: tc.math.log,
lax.log1p_p: tc.math.log1p,
lax.sqrt_p: tc.sqrt,
lax.sqrt_p: tc.math.sqrt,
lax.cbrt_p: tc.math.cbrt,
lax.rsqrt_p: tc.math.rsqrt,
lax.sin_p: tc.sin,
lax.cos_p: tc.cos,
lax.sin_p: tc.math.sin,
lax.cos_p: tc.math.cos,
lax.tan_p: tc.math.tan,
lax.asin_p: tc.math.asin,
lax.acos_p: tc.math.acos,

View File

@ -19,7 +19,7 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from functools import partial, wraps
import threading
@ -1051,22 +1051,133 @@ def set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
op.attributes[name] = attr
_LIBDEVICE_PATH = tl.math.libdevice_path()
def libdevice_extern_elementwise(
table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
is_pure: bool = True,
):
def inner(arg: tensor):
try:
symbol, dtype = table[(arg.dtype,)]
except KeyError:
raise NotImplementedError(f"unsupported dtypes: {(arg.dtype,)}") from None
return_type = dtype
if arg.type.is_block():
return_type = block_type(dtype, arg.shape)
return tensor(
tt_dialect.extern_elementwise(
return_type.to_ir(builder.current),
[arg.handle],
libname="libdevice",
libpath=_LIBDEVICE_PATH,
symbol=symbol,
pure=is_pure,
),
return_type,
)
return inner
class math:
acos = wrap_with_builder(tl.math.acos)
acosh = wrap_with_builder(tl.math.acosh)
asin = wrap_with_builder(tl.math.asin)
asinh = wrap_with_builder(tl.math.asinh)
atan = wrap_with_builder(tl.math.atan)
atan2 = wrap_with_builder(tl.math.atan2)
atanh = wrap_with_builder(tl.math.atanh)
cbrt = wrap_with_builder(tl.math.cbrt)
ceil = wrap_with_builder(tl.math.ceil)
clz = wrap_with_builder(tl.math.clz)
cosh = wrap_with_builder(tl.math.cosh)
exp2 = wrap_with_builder(tl.math.exp2)
expm1 = wrap_with_builder(tl.math.expm1)
floor = wrap_with_builder(tl.math.floor)
log1p = wrap_with_builder(tl.math.log1p)
sin = libdevice_extern_elementwise({
(float32,): ("__nv_sinf", float32),
(float64,): ("__nv_sin", float64),
})
cos = libdevice_extern_elementwise({
(float32,): ("__nv_cosf", float32),
(float64,): ("__nv_cos", float64),
})
tan = libdevice_extern_elementwise({
(float32,): ("__nv_tanf", float32),
(float64,): ("__nv_tan", float64),
})
asin = libdevice_extern_elementwise({
(float32,): ("__nv_asinf", float32),
(float64,): ("__nv_asin", float64),
})
acos = libdevice_extern_elementwise({
(float32,): ("__nv_acosf", float32),
(float64,): ("__nv_acos", float64),
})
atan = libdevice_extern_elementwise({
(float32,): ("__nv_atanf", float32),
(float64,): ("__nv_atan", float64),
})
atan2 = libdevice_extern_elementwise({
(float32,): ("__nv_atan2f", float32),
(float64,): ("__nv_atan2", float64),
})
sinh = libdevice_extern_elementwise({
(float32,): ("__nv_sinhf", float32),
(float64,): ("__nv_sinh", float64),
})
cosh = libdevice_extern_elementwise({
(float32,): ("__nv_coshf", float32),
(float64,): ("__nv_cosh", float64),
})
tanh = libdevice_extern_elementwise({
(float32,): ("__nv_tanhf", float32),
(float64,): ("__nv_tanh", float64),
})
asinh = libdevice_extern_elementwise({
(float32,): ("__nv_asinhf", float32),
(float64,): ("__nv_asinh", float64),
})
acosh = libdevice_extern_elementwise({
(float32,): ("__nv_acosf", float32),
(float64,): ("__nv_acosh", float64),
})
atanh = libdevice_extern_elementwise({
(float32,): ("__nv_atanhf", float32),
(float64,): ("__nv_atanh", float64),
})
cbrt = libdevice_extern_elementwise({
(float32,): ("__nv_cbrtf", float32),
(float64,): ("__nv_cbrt", float64),
})
clz = libdevice_extern_elementwise({
(int32,): ("__nv_clz", int32),
(int64,): ("__nv_clzll", int64),
})
exp = libdevice_extern_elementwise({
(float32,): ("__nv_expf", float32),
(float64,): ("__nv_exp", float64),
})
exp2 = libdevice_extern_elementwise({
(float32,): ("__nv_exp2f", float32),
(float64,): ("__nv_exp2", float64),
})
expm1 = libdevice_extern_elementwise({
(float32,): ("__nv_expm1f", float32),
(float64,): ("__nv_expm1", float64),
})
log = libdevice_extern_elementwise({
(float32,): ("__nv_logf", float32),
(float64,): ("__nv_log", float64),
})
log1p = libdevice_extern_elementwise({
(float32,): ("__nv_log1pf", float32),
(float64,): ("__nv_log1p", float64),
})
floor = libdevice_extern_elementwise({
(float32,): ("__nv_floorf", float32),
(float64,): ("__nv_floor", float64),
})
ceil = libdevice_extern_elementwise({
(float32,): ("__nv_ceilf", float32),
(float64,): ("__nv_ceil", float64),
})
abs = libdevice_extern_elementwise({
(int32,): ("__nv_abs", int32),
(int64,): ("__nv_llabs", int64),
(float32,): ("__nv_fabsf", float32),
(float64,): ("__nv_fabs", float64),
})
max = partial(
wrap_with_builder(tl.math.max),
propagate_nan=tl.PropagateNan.NONE,
@ -1076,12 +1187,19 @@ class math:
propagate_nan=tl.PropagateNan.NONE,
)
nextafter = wrap_with_builder(tl.math.nextafter)
popc = wrap_with_builder(tl.math.popc)
popc = libdevice_extern_elementwise({
(int32,): ("__nv_popc", int32),
(int64,): ("__nv_popcll", int64),
})
pow = wrap_with_builder(tl.math.pow)
rsqrt = wrap_with_builder(tl.math.rsqrt)
sinh = wrap_with_builder(tl.math.sinh)
tan = wrap_with_builder(tl.math.tan)
tanh = wrap_with_builder(tl.math.tanh)
sqrt = libdevice_extern_elementwise({
(float32,): ("__nv_sqrtf", float32),
(float64,): ("__nv_sqrt", float64),
})
rsqrt = libdevice_extern_elementwise({
(float32,): ("__nv_rsqrtf", float32),
(float64,): ("__nv_rsqrt", float64),
})
class semantic: