Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00)

Migrated remaining operations from the math namespace to lower directly to Triton IR

PiperOrigin-RevId: 602390761

parent 07f8f700ca
commit fad3e749a1
@@ -753,7 +753,7 @@ uint64 = tl.core.uint64
 def _bool_block_like(v: tensor) -> block_type:
   if not v.type.is_block():
     return int1
-  return tl.block_type(int1, v.type.shape)
+  return block_type(int1, v.type.shape)


 def wrap_with_builder(fn):
@@ -868,7 +868,7 @@ class tensor(tl.core.tensor):
 def program_id(axis: int) -> tensor:
   if axis not in range(3):
     raise ValueError(f"axis must be in [0, 3), but got: {axis}")
-  return tensor(tt_dialect.get_program_id(axis), tl.int32)
+  return tensor(tt_dialect.get_program_id(axis), int32)


 load = wrap_with_builder(tl.core.load)
@@ -882,7 +882,7 @@ def arange(start: int, end: int) -> tensor:
   )
   if max(start, end) >= 2**32:
     raise ValueError("start and end must fit in int32")
-  ty = block_type(tl.int32, [end - start])
+  ty = block_type(int32, [end - start])
   ir_ty = ir.RankedTensorType.get(
       [end - start], ir.IntegerType.get_signless(32)
   )
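A note on the bound in the `arange` hunk above: the guard rejects values at or above 2**32 rather than 2**31, presumably because the tensor is built as a signless MLIR i32 (see `ir.IntegerType.get_signless(32)` in the same hunk), which holds any 32-bit pattern; values in [2**31, 2**32) overflow a *signed* int32 but still fit in 32 bits. A standalone illustration of that wrap (plain Python with NumPy, not part of the commit):

    import numpy as np

    # 2**32 - 1 fits in 32 bits, so the guard admits it...
    assert (2**32 - 1).bit_length() == 32
    # ...but the same bit pattern reads as -1 under a signed int32 view.
    assert np.uint32(2**32 - 1).astype(np.int32) == -1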
@@ -1115,11 +1115,15 @@ def libdevice_extern_elementwise(
     table: Mapping[tuple[dtype, ...], tuple[str, dtype]],
     is_pure: bool = True,
 ):
-  def inner(arg: tensor):
+  def inner(arg: tensor, *rest: tensor) -> tensor:
+    args = (arg, *rest)
+    assert all(other.shape == arg.shape for other in rest)
+    key = tuple(arg.dtype for arg in args)
     try:
-      symbol, dtype = table[(arg.dtype,)]
+      symbol, dtype = table[key]
     except KeyError:
-      raise NotImplementedError(f"unsupported dtypes: {(arg.dtype,)}") from None
+      raise NotImplementedError(f"unsupported dtypes: {key}") from None

     return_type = dtype
     if arg.type.is_block():
@@ -1127,7 +1131,7 @@ def libdevice_extern_elementwise(
     return tensor(
         tt_dialect.extern_elementwise(
             return_type.to_ir(builder.current),
-            [arg.handle],
+            [arg.handle for arg in args],
             libname="libdevice",
             libpath=_LIBDEVICE_PATH,
             symbol=symbol,
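Taken together, the two hunks above generalize `libdevice_extern_elementwise` from unary-only to n-ary: the table key is now the tuple of all argument dtypes, and every argument handle is forwarded to `tt_dialect.extern_elementwise`. A simplified pure-Python sketch of just the table dispatch (hypothetical mini-version for illustration; `FakeTensor` and `make_elementwise` are not names from the commit):

    from typing import Mapping

    class FakeTensor:
      def __init__(self, dtype: str):
        self.dtype = dtype

    def make_elementwise(table: Mapping[tuple, tuple]):
      def inner(*args: FakeTensor) -> tuple:
        # Key on the dtypes of *all* arguments, mirroring the new `inner`.
        key = tuple(a.dtype for a in args)
        try:
          symbol, result_dtype = table[key]
        except KeyError:
          raise NotImplementedError(f"unsupported dtypes: {key}") from None
        return symbol, result_dtype
      return inner

    pow_lookup = make_elementwise({
        ("float32", "int32"): ("__nv_powif", "float32"),
        ("float32", "float32"): ("__nv_powf", "float32"),
    })
    assert pow_lookup(FakeTensor("float32"), FakeTensor("int32")) == ("__nv_powif", "float32")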
@@ -1140,6 +1144,34 @@ def libdevice_extern_elementwise(


 class math:
+  @staticmethod
+  def max(x: tensor, y: tensor) -> tensor:
+    # TODO(slebedev): Consider allowing customizing nan behavior.
+    assert x.shape == y.shape
+    if x.dtype.is_floating():
+      # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
+      return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.dtype)
+    if not x.dtype.is_int():
+      raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
+    elif x.dtype.is_int_signed():
+      return tensor(arith_dialect.maxsi(x.handle, y.handle), x.dtype)
+    else:
+      return tensor(arith_dialect.maxui(x.handle, y.handle), x.dtype)
+
+  @staticmethod
+  def min(x: tensor, y: tensor) -> tensor:
+    # TODO(slebedev): Consider allowing customizing nan behavior.
+    assert x.shape == y.shape
+    if x.dtype.is_floating():
+      # TODO(slebedev): Triton promotes bfloat16 to float32 and back here.
+      return tensor(arith_dialect.minnumf(x.handle, y.handle), x.dtype)
+    if not x.dtype.is_int():
+      raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")
+    elif x.dtype.is_int_signed():
+      return tensor(arith_dialect.minsi(x.handle, y.handle), x.dtype)
+    else:
+      return tensor(arith_dialect.minui(x.handle, y.handle), x.dtype)
+
   sin = libdevice_extern_elementwise({
       (float32,): ("__nv_sinf", float32),
       (float64,): ("__nv_sin", float64),
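The new `math.max`/`math.min` pick the MLIR `arith` op by dtype interpretation: `maxnumf`/`minnumf` for floats, `maxsi`/`minsi` for signed integers, `maxui`/`minui` for unsigned ones. The signed/unsigned split matters because the same bit pattern orders differently under each view; a plain-Python illustration assuming 32-bit two's-complement (not part of the commit):

    def max_signed(x: int, y: int, bits: int = 32) -> int:
      # Compare the raw patterns as two's-complement signed values.
      def to_signed(v: int) -> int:
        v &= (1 << bits) - 1
        return v - (1 << bits) if v >= 1 << (bits - 1) else v
      return max(x, y, key=to_signed)

    def max_unsigned(x: int, y: int, bits: int = 32) -> int:
      # Compare the raw patterns as unsigned values.
      mask = (1 << bits) - 1
      return max(x & mask, y & mask)

    # 0xFFFFFFFF is -1 when signed, 4294967295 when unsigned:
    assert max_signed(0xFFFFFFFF, 1) == 1             # arith.maxsi would pick 1
    assert max_unsigned(0xFFFFFFFF, 1) == 0xFFFFFFFF  # arith.maxui picks the large pattern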
@@ -1235,20 +1267,20 @@ class math:
       (float32,): ("__nv_fabsf", float32),
       (float64,): ("__nv_fabs", float64),
   })
-  max = partial(
-      wrap_with_builder(tl.math.max),
-      propagate_nan=tl.PropagateNan.NONE,
-  )
-  min = partial(
-      wrap_with_builder(tl.math.min),
-      propagate_nan=tl.PropagateNan.NONE,
-  )
-  nextafter = wrap_with_builder(tl.math.nextafter)
+  nextafter = libdevice_extern_elementwise({
+      (float32, float32): ("__nv_nextafterf", float32),
+      (float64, float64): ("__nv_nextafter", float64),
+  })
   popc = libdevice_extern_elementwise({
       (int32,): ("__nv_popc", int32),
       (int64,): ("__nv_popcll", int64),
   })
-  pow = wrap_with_builder(tl.math.pow)
+  pow = libdevice_extern_elementwise({
+      (float32, int32): ("__nv_powif", float32),
+      (float64, int32): ("__nv_powi", float64),
+      (float32, float32): ("__nv_powf", float32),
+      (float64, float64): ("__nv_pow", float64),
+  })
   sqrt = libdevice_extern_elementwise({
       (float32,): ("__nv_sqrtf", float32),
       (float64,): ("__nv_sqrt", float64),
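With the n-ary table in place, `nextafter` and `pow` drop their `wrap_with_builder` shims and dispatch on dtype pairs, e.g. a float32 base with an int32 exponent selects `__nv_powif`. For reference, `nextafter(x, y)` returns the next representable float after `x` in the direction of `y`; Python's own `math.nextafter` (3.9+) shows the float64 semantics that `__nv_nextafter` provides (standalone illustration, unrelated to the `math` class in the diff):

    import math
    import sys

    # One ULP above 1.0 is 1.0 + machine epsilon for float64.
    assert math.nextafter(1.0, 2.0) == 1.0 + sys.float_info.epsilon
    # Stepping toward 0.0 gives the nearest float strictly below 1.0.
    assert math.nextafter(1.0, 0.0) < 1.0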
@@ -1261,15 +1293,19 @@ class math:

 class semantic:
   cast = wrap_with_builder(tl.semantic.cast)
-  where = wrap_with_builder(tl.semantic.where)
+
+  @staticmethod
+  def where(cond: tensor, x: tensor, y: tensor) -> tensor:
+    assert cond.shape == x.shape == y.shape
+    return tensor(arith_dialect.select(cond.handle, x.handle, y.handle), x.type)

   @staticmethod
   def trans(x: tensor) -> tensor:
     if len(x.shape) != 2:
       raise NotImplementedError(f"unsupported shape: {x.shape}")
-    return tl.tensor(
+    return tensor(
         tt_dialect.trans(x.handle),
-        tl.block_type(x.dtype, [*reversed(x.shape)]),
+        block_type(x.dtype, [*reversed(x.shape)]),
     )

   @staticmethod
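`semantic.where` now lowers straight to `arith.select`, an elementwise ternary: where the condition holds, take the first operand, otherwise the second. NumPy's `where` models the same semantics (illustration only, not the commit's code):

    import numpy as np

    cond = np.array([True, False, True])
    x = np.array([1, 2, 3], dtype=np.int32)
    y = np.array([10, 20, 30], dtype=np.int32)
    # Elementwise select: true lanes from x, false lanes from y.
    assert (np.where(cond, x, y) == np.array([1, 20, 3])).all()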
@@ -1339,7 +1375,7 @@ class semantic:
       y = semantic.cast(y, float32)
     if x.dtype.is_floating():
       assert y.dtype.is_floating()
-      return tl.tensor(arith_dialect.divf(x.handle, y.handle), x.type)
+      return tensor(arith_dialect.divf(x.handle, y.handle), x.type)
     raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}")

   @staticmethod
@@ -1360,27 +1396,27 @@ class semantic:

   @staticmethod
   def and_(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.andi(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.andi(x.handle, y.handle), x.type)

   @staticmethod
   def or_(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.ori(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.ori(x.handle, y.handle), x.type)

   @staticmethod
   def xor_(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.xori(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.xori(x.handle, y.handle), x.type)

   @staticmethod
   def lshr(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.shrui(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.shrui(x.handle, y.handle), x.type)

   @staticmethod
   def ashr(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.shrsi(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.shrsi(x.handle, y.handle), x.type)

   @staticmethod
   def shl(x: tensor, y: tensor) -> tensor:
-    return tl.tensor(arith_dialect.shli(x.handle, y.handle), x.type)
+    return tensor(arith_dialect.shli(x.handle, y.handle), x.type)

   @staticmethod
   def equal(x: tensor, y: tensor) -> tensor:
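The `lshr`/`ashr` pair above maps to `arith.shrui` versus `arith.shrsi`: a logical right shift fills the vacated high bits with zeros, while an arithmetic right shift replicates the sign bit. In plain Python over 32-bit patterns (illustration only, not part of the commit):

    def lshr32(x: int, n: int) -> int:
      # Logical shift: treat x as an unsigned 32-bit pattern; zeros shift in.
      return (x & 0xFFFFFFFF) >> n

    def ashr32(x: int, n: int) -> int:
      # Arithmetic shift: sign-extend first, then shift; the sign bit repeats.
      signed = x - (1 << 32) if x & 0x80000000 else x
      return (signed >> n) & 0xFFFFFFFF

    x = 0x80000000  # sign bit set: -2**31 under a signed view
    assert lshr32(x, 4) == 0x08000000
    assert ashr32(x, 4) == 0xF8000000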