Migrated remaining operations from the math namespace to lower directly to Triton IR

PiperOrigin-RevId: 602390761
Sergei Lebedev 2024-01-29 08:09:10 -08:00 committed by jax authors
parent 07f8f700ca
commit fad3e749a1

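The recurring pattern in this diff is `libdevice_extern_elementwise`: the tuple of argument dtypes selects a libdevice symbol and a result dtype, and the call lowers to a single `tt.extern_elementwise` op. Below is a minimal self-contained sketch of that table-driven dispatch, using plain strings as stand-ins for the real Triton dtypes and MLIR handles; the names are illustrative, not the actual module.

from typing import Callable, Mapping

def extern_elementwise(
    table: Mapping[tuple[str, ...], tuple[str, str]],
) -> Callable[..., str]:
  # `table` maps a tuple of argument dtypes to (libdevice symbol, result dtype).
  def inner(*dtypes: str) -> str:
    try:
      symbol, result = table[dtypes]
    except KeyError:
      raise NotImplementedError(f"unsupported dtypes: {dtypes}") from None
    # The real code emits a tt.extern_elementwise MLIR op here; this sketch
    # returns a string describing it instead.
    return f"tt.extern_elementwise @{symbol} : {result}"
  return inner

pow = extern_elementwise({
    ("float32", "float32"): ("__nv_powf", "float32"),
    ("float64", "float64"): ("__nv_pow", "float64"),
})
assert pow("float32", "float32") == "tt.extern_elementwise @__nv_powf : float32"

This table-driven wrapper replaces the `wrap_with_builder(tl.math.*)` indirection for `nextafter` and `pow` below; ops with a native MLIR equivalent (`max`, `min`, `where`, the bitwise ops) instead emit `arith` dialect ops directly.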

@@ -753,7 +753,7 @@ uint64 = tl.core.uint64
 def _bool_block_like(v: tensor) -> block_type:
   if not v.type.is_block():
     return int1
-  return tl.block_type(int1, v.type.shape)
+  return block_type(int1, v.type.shape)


 def wrap_with_builder(fn):
@@ -868,7 +868,7 @@ class tensor(tl.core.tensor):
 def program_id(axis: int) -> tensor:
   if axis not in range(3):
     raise ValueError(f"axis must be in [0, 3), but got: {axis}")
-  return tensor(tt_dialect.get_program_id(axis), tl.int32)
+  return tensor(tt_dialect.get_program_id(axis), int32)


 load = wrap_with_builder(tl.core.load)
@@ -882,7 +882,7 @@ def arange(start: int, end: int) -> tensor:
   )
   if max(start, end) >= 2**32:
     raise ValueError("start and end must fit in int32")
-  ty = block_type(tl.int32, [end - start])
+  ty = block_type(int32, [end - start])
   ir_ty = ir.RankedTensorType.get(
       [end - start], ir.IntegerType.get_signless(32)
   )
@@ -1115,11 +1115,15 @@ def libdevice_extern_elementwise(
     table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
     is_pure: bool = True,
 ):
-  def inner(arg: tensor):
+  def inner(arg: tensor, *rest: tensor) -> tensor:
+    args = (arg, *rest)
+    assert all(other.shape == arg.shape for other in rest)
+    key = tuple(arg.dtype for arg in args)
     try:
-      symbol, dtype = table[(arg.dtype,)]
+      symbol, dtype = table[key]
     except KeyError:
-      raise NotImplementedError(f"unsupported dtypes: {(arg.dtype,)}") from None
+      raise NotImplementedError(f"unsupported dtypes: {key}") from None

     return_type = dtype
     if arg.type.is_block():
@@ -1127,7 +1131,7 @@ def libdevice_extern_elementwise(
     return tensor(
         tt_dialect.extern_elementwise(
             return_type.to_ir(builder.current),
-            [arg.handle],
+            [arg.handle for arg in args],
             libname="libdevice",
             libpath=_LIBDEVICE_PATH,
             symbol=symbol,
@@ -1140,6 +1144,34 @@ def libdevice_extern_elementwise(


 class math:
+  @staticmethod
+  def max(x: tensor, y: tensor) -> tensor:
+    # TODO(slebedev): Consider allowing customizing nan behavior.
+    assert x.shape == y.shape
+    if x.dtype.is_floating():
+      # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
+      return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.dtype)
+    if not x.dtype.is_int():
+      raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
+    elif x.dtype.is_int_signed():
+      return tensor(arith_dialect.maxsi(x.handle, y.handle), x.dtype)
+    else:
+      return tensor(arith_dialect.maxui(x.handle, y.handle), x.dtype)
+
+  @staticmethod
+  def min(x: tensor, y: tensor) -> tensor:
+    # TODO(slebedev): Consider allowing customizing nan behavior.
+    assert x.shape == y.shape
+    if x.dtype.is_floating():
+      # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
+      return tensor(arith_dialect.minnumf(x.handle, y.handle), x.dtype)
+    if not x.dtype.is_int():
+      raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
+    elif x.dtype.is_int_signed():
+      return tensor(arith_dialect.minsi(x.handle, y.handle), x.dtype)
+    else:
+      return tensor(arith_dialect.minui(x.handle, y.handle), x.dtype)
+
   sin = libdevice_extern_elementwise({
       (float32,): ("__nv_sinf", float32),
       (float64,): ("__nv_sin", float64),
@@ -1235,20 +1267,20 @@ class math:
       (float32,): ("__nv_fabsf", float32),
       (float64,): ("__nv_fabs", float64),
   })
-  max = partial(
-      wrap_with_builder(tl.math.max),
-      propagate_nan=tl.PropagateNan.NONE,
-  )
-  min = partial(
-      wrap_with_builder(tl.math.min),
-      propagate_nan=tl.PropagateNan.NONE,
-  )
-  nextafter = wrap_with_builder(tl.math.nextafter)
+  nextafter = libdevice_extern_elementwise({
+      (float32, float32): ("__nv_nextafterf", float32),
+      (float64, float64): ("__nv_nextafter", float64),
+  })
   popc = libdevice_extern_elementwise({
       (int32,): ("__nv_popc", int32),
       (int64,): ("__nv_popcll", int64),
   })
-  pow = wrap_with_builder(tl.math.pow)
+  pow = libdevice_extern_elementwise({
+      (float32, int32): ("__nv_powif", float32),
+      (float64, int32): ("__nv_powi", float64),
+      (float32, float32): ("__nv_powf", float32),
+      (float64, float64): ("__nv_pow", float64),
+  })
   sqrt = libdevice_extern_elementwise({
       (float32,): ("__nv_sqrtf", float32),
       (float64,): ("__nv_sqrt", float64),
@@ -1261,15 +1293,19 @@ class math:


 class semantic:
   cast = wrap_with_builder(tl.semantic.cast)
-  where = wrap_with_builder(tl.semantic.where)
+
+  @staticmethod
+  def where(cond: tensor, x: tensor, y: tensor) -> tensor:
+    assert cond.shape == x.shape == y.shape
+    return tensor(arith_dialect.select(cond.handle, x.handle, y.handle), x.type)

   @staticmethod
   def trans(x: tensor) -> tensor:
     if len(x.shape) != 2:
       raise NotImplementedError(f"unsupported shape: {x.shape}")
-    return tl.tensor(
+    return tensor(
         tt_dialect.trans(x.handle),
-        tl.block_type(x.dtype, [*reversed(x.shape)]),
+        block_type(x.dtype, [*reversed(x.shape)]),
     )

   @staticmethod
@@ -1339,7 +1375,7 @@ class semantic:
       y = semantic.cast(y, float32)
     if x.dtype.is_floating():
       assert y.dtype.is_floating()
-      return tl.tensor(arith_dialect.divf(x.handle, y.handle), x.type)
+      return tensor(arith_dialect.divf(x.handle, y.handle), x.type)
     raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")

   @staticmethod
@@ -1360,27 +1396,27 @@ class semantic:
   @staticmethod
   def and_(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.andi(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.andi(x.handle, y.handle), x.type)

   @staticmethod
   def or_(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.ori(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.ori(x.handle, y.handle), x.type)

   @staticmethod
   def xor_(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.xori(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.xori(x.handle, y.handle), x.type)

   @staticmethod
   def lshr(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.shrui(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.shrui(x.handle, y.handle), x.type)

   @staticmethod
   def ashr(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.shrsi(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.shrsi(x.handle, y.handle), x.type)

   @staticmethod
   def shl(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.shli(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.shli(x.handle, y.handle), x.type)

   @staticmethod
   def equal(x: tensor, y: tensor) -> tensor: