mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Lower a subset of math primitives directly to Triton IR
Note that all primitives are now lowered to libdevice calls. Previously, some of them were lowered to the MLIR arith dialect, and some to libdevice calls, without any apparent reason for doing so. PiperOrigin-RevId: 601259707
This commit is contained in:
parent
cfb6250158
commit
f15cad4651
@ -463,19 +463,19 @@ triton_lowering_rules[lax.cumsum_p] = _cumsum_lowering_rule
|
||||
_TRITON_FN_MAPPING = {
|
||||
# Unary ops.
|
||||
lax.neg_p: tc.semantic.minus,
|
||||
lax.abs_p: tc.abs,
|
||||
lax.abs_p: tc.math.abs,
|
||||
lax.ceil_p: tc.math.ceil,
|
||||
lax.floor_p: tc.math.floor,
|
||||
lax.exp_p: tc.exp,
|
||||
lax.exp_p: tc.math.exp,
|
||||
lax.exp2_p: tc.math.exp2,
|
||||
lax.expm1_p: tc.math.expm1,
|
||||
lax.log_p: tc.log,
|
||||
lax.log_p: tc.math.log,
|
||||
lax.log1p_p: tc.math.log1p,
|
||||
lax.sqrt_p: tc.sqrt,
|
||||
lax.sqrt_p: tc.math.sqrt,
|
||||
lax.cbrt_p: tc.math.cbrt,
|
||||
lax.rsqrt_p: tc.math.rsqrt,
|
||||
lax.sin_p: tc.sin,
|
||||
lax.cos_p: tc.cos,
|
||||
lax.sin_p: tc.math.sin,
|
||||
lax.cos_p: tc.math.cos,
|
||||
lax.tan_p: tc.math.tan,
|
||||
lax.asin_p: tc.math.asin,
|
||||
lax.acos_p: tc.math.acos,
|
||||
|
@ -19,7 +19,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import partial, wraps
|
||||
import threading
|
||||
|
||||
@ -1051,22 +1051,133 @@ def set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
|
||||
op.attributes[name] = attr
|
||||
|
||||
|
||||
_LIBDEVICE_PATH = tl.math.libdevice_path()
|
||||
|
||||
|
||||
def libdevice_extern_elementwise(
|
||||
table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
|
||||
is_pure: bool = True,
|
||||
):
|
||||
def inner(arg: tensor):
|
||||
try:
|
||||
symbol, dtype = table[(arg.dtype,)]
|
||||
except KeyError:
|
||||
raise NotImplementedError(f"unsupported dtypes: {(arg.dtype,)}") from None
|
||||
|
||||
return_type = dtype
|
||||
if arg.type.is_block():
|
||||
return_type = block_type(dtype, arg.shape)
|
||||
return tensor(
|
||||
tt_dialect.extern_elementwise(
|
||||
return_type.to_ir(builder.current),
|
||||
[arg.handle],
|
||||
libname="libdevice",
|
||||
libpath=_LIBDEVICE_PATH,
|
||||
symbol=symbol,
|
||||
pure=is_pure,
|
||||
),
|
||||
return_type,
|
||||
)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class math:
|
||||
acos = wrap_with_builder(tl.math.acos)
|
||||
acosh = wrap_with_builder(tl.math.acosh)
|
||||
asin = wrap_with_builder(tl.math.asin)
|
||||
asinh = wrap_with_builder(tl.math.asinh)
|
||||
atan = wrap_with_builder(tl.math.atan)
|
||||
atan2 = wrap_with_builder(tl.math.atan2)
|
||||
atanh = wrap_with_builder(tl.math.atanh)
|
||||
cbrt = wrap_with_builder(tl.math.cbrt)
|
||||
ceil = wrap_with_builder(tl.math.ceil)
|
||||
clz = wrap_with_builder(tl.math.clz)
|
||||
cosh = wrap_with_builder(tl.math.cosh)
|
||||
exp2 = wrap_with_builder(tl.math.exp2)
|
||||
expm1 = wrap_with_builder(tl.math.expm1)
|
||||
floor = wrap_with_builder(tl.math.floor)
|
||||
log1p = wrap_with_builder(tl.math.log1p)
|
||||
sin = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_sinf", float32),
|
||||
(float64,): ("__nv_sin", float64),
|
||||
})
|
||||
cos = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_cosf", float32),
|
||||
(float64,): ("__nv_cos", float64),
|
||||
})
|
||||
tan = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_tanf", float32),
|
||||
(float64,): ("__nv_tan", float64),
|
||||
})
|
||||
asin = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_asinf", float32),
|
||||
(float64,): ("__nv_asin", float64),
|
||||
})
|
||||
acos = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_acosf", float32),
|
||||
(float64,): ("__nv_acos", float64),
|
||||
})
|
||||
atan = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_atanf", float32),
|
||||
(float64,): ("__nv_atan", float64),
|
||||
})
|
||||
atan2 = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_atan2f", float32),
|
||||
(float64,): ("__nv_atan2", float64),
|
||||
})
|
||||
sinh = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_sinhf", float32),
|
||||
(float64,): ("__nv_sinh", float64),
|
||||
})
|
||||
cosh = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_coshf", float32),
|
||||
(float64,): ("__nv_cosh", float64),
|
||||
})
|
||||
tanh = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_tanhf", float32),
|
||||
(float64,): ("__nv_tanh", float64),
|
||||
})
|
||||
asinh = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_asinhf", float32),
|
||||
(float64,): ("__nv_asinh", float64),
|
||||
})
|
||||
acosh = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_acosf", float32),
|
||||
(float64,): ("__nv_acosh", float64),
|
||||
})
|
||||
atanh = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_atanhf", float32),
|
||||
(float64,): ("__nv_atanh", float64),
|
||||
})
|
||||
|
||||
cbrt = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_cbrtf", float32),
|
||||
(float64,): ("__nv_cbrt", float64),
|
||||
})
|
||||
clz = libdevice_extern_elementwise({
|
||||
(int32,): ("__nv_clz", int32),
|
||||
(int64,): ("__nv_clzll", int64),
|
||||
})
|
||||
exp = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_expf", float32),
|
||||
(float64,): ("__nv_exp", float64),
|
||||
})
|
||||
exp2 = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_exp2f", float32),
|
||||
(float64,): ("__nv_exp2", float64),
|
||||
})
|
||||
expm1 = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_expm1f", float32),
|
||||
(float64,): ("__nv_expm1", float64),
|
||||
})
|
||||
log = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_logf", float32),
|
||||
(float64,): ("__nv_log", float64),
|
||||
})
|
||||
log1p = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_log1pf", float32),
|
||||
(float64,): ("__nv_log1p", float64),
|
||||
})
|
||||
floor = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_floorf", float32),
|
||||
(float64,): ("__nv_floor", float64),
|
||||
})
|
||||
ceil = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_ceilf", float32),
|
||||
(float64,): ("__nv_ceil", float64),
|
||||
})
|
||||
abs = libdevice_extern_elementwise({
|
||||
(int32,): ("__nv_abs", int32),
|
||||
(int64,): ("__nv_llabs", int64),
|
||||
(float32,): ("__nv_fabsf", float32),
|
||||
(float64,): ("__nv_fabs", float64),
|
||||
})
|
||||
max = partial(
|
||||
wrap_with_builder(tl.math.max),
|
||||
propagate_nan=tl.PropagateNan.NONE,
|
||||
@ -1076,12 +1187,19 @@ class math:
|
||||
propagate_nan=tl.PropagateNan.NONE,
|
||||
)
|
||||
nextafter = wrap_with_builder(tl.math.nextafter)
|
||||
popc = wrap_with_builder(tl.math.popc)
|
||||
popc = libdevice_extern_elementwise({
|
||||
(int32,): ("__nv_popc", int32),
|
||||
(int64,): ("__nv_popcll", int64),
|
||||
})
|
||||
pow = wrap_with_builder(tl.math.pow)
|
||||
rsqrt = wrap_with_builder(tl.math.rsqrt)
|
||||
sinh = wrap_with_builder(tl.math.sinh)
|
||||
tan = wrap_with_builder(tl.math.tan)
|
||||
tanh = wrap_with_builder(tl.math.tanh)
|
||||
sqrt = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_sqrtf", float32),
|
||||
(float64,): ("__nv_sqrt", float64),
|
||||
})
|
||||
rsqrt = libdevice_extern_elementwise({
|
||||
(float32,): ("__nv_rsqrtf", float32),
|
||||
(float64,): ("__nv_rsqrt", float64),
|
||||
})
|
||||
|
||||
|
||||
class semantic:
|
||||
|
Loading…
x
Reference in New Issue
Block a user