Migrated a subset of triton.compat to directly use IR builders

PiperOrigin-RevId: 598826331
Sergei Lebedev 2024-01-16 06:45:31 -08:00 committed by jax authors
parent ab8eb896d7
commit af49b01e1f

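For orientation, here is a minimal, self-contained toy contrasting the two styles this commit bridges. `FakeOp` and `fake_tt_dialect` are stand-ins invented for this sketch; nothing below imports the real Triton or MLIR bindings.

class FakeOp:
  """Stands in for the value an MLIR op builder returns."""

  def __init__(self, name: str, *operands):
    self.name, self.operands = name, operands

  def __repr__(self):
    return f"{self.name}({', '.join(map(repr, self.operands))})"

class fake_tt_dialect:
  @staticmethod
  def get_program_id(axis: int) -> FakeOp:
    return FakeOp("tt.get_program_id", axis)

# Old style: every op is reached through a method on a builder object,
# mirroring triton.compat's `builder` class.
class builder:
  def create_get_program_id(self, axis: int) -> FakeOp:
    assert 0 <= axis < 3
    return fake_tt_dialect.get_program_id(axis)

# New style: call the dialect op builder directly and validate inline.
def program_id(axis: int) -> FakeOp:
  if axis not in range(3):
    raise ValueError(f"axis must be in [0, 3), but got: {axis}")
  return fake_tt_dialect.get_program_id(axis)

print(builder().create_get_program_id(1))  # tt.get_program_id(1)
print(program_id(1))                       # the same op, one less hop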
@@ -253,10 +253,6 @@ class builder:
   ) -> scf_dialect.ConditionOp:
     return scf_dialect.ConditionOp(cond, args)
 
-  def create_make_range(self, start: int, end: int) -> ir.Value:
-    result = ir.RankedTensorType.get([end - start], self.get_int32_ty())
-    return tt_dialect.make_range(result, start, end)
-
   def create_fp_to_fp(self, src: ir.Value, dst_type: ir.Type) -> ir.Value:
     return tt_dialect.fp_to_fp(dst_type, src)
@@ -571,19 +567,6 @@ class builder:
         evict=eviction_policy,
     )
 
-  def create_reshape(
-      self, arg: ir.Value, shape: Sequence[int], allow_reorder: bool
-  ) -> ir.Value:
-    assert ir.RankedTensorType.isinstance(arg.type)
-    arg_type = ir.RankedTensorType(arg.type)
-    result_type = ir.RankedTensorType.get(
-        shape, arg_type.element_type, arg_type.encoding
-    )
-    return tt_dialect.reshape(result_type, arg, allow_reorder)
-
-  def create_expand_dims(self, arg: ir.Value, axis: int) -> ir.Value:
-    return tt_dialect.expand_dims(arg, axis)
-
   def create_cat(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value:
     assert ir.RankedTensorType.isinstance(lhs.type)
     assert ir.RankedTensorType.isinstance(rhs.type)
@@ -667,10 +650,6 @@ class builder:
         return_type, args, lib_name, lib_path, symbol, is_pure
     )
 
-  def create_get_program_id(self, axis: int) -> ir.Value:
-    assert 0 <= axis < 3
-    return tt_dialect.get_program_id(axis)
-
   def create_get_num_programs(self, axis: int) -> ir.Value:
     return tt_dialect.get_num_programs(axis)
@@ -684,27 +663,6 @@ class builder:
   ) -> ir.Value:
     return tt_dialect.dot(a, b, c, allow_tf32, max_num_imprecise_acc)
 
-  def create_exp(self, val: ir.Value) -> ir.Value:
-    return math_dialect.exp(val)
-
-  def create_cos(self, val: ir.Value) -> ir.Value:
-    return math_dialect.cos(val)
-
-  def create_sin(self, val: ir.Value) -> ir.Value:
-    return math_dialect.sin(val)
-
-  def create_log(self, val: ir.Value) -> ir.Value:
-    return math_dialect.log(val)
-
-  def create_sqrt(self, val: ir.Value) -> ir.Value:
-    return math_dialect.sqrt(val)
-
-  def create_fabs(self, val: ir.Value) -> ir.Value:
-    return math_dialect.absf(val)
-
-  def create_iabs(self, val: ir.Value) -> ir.Value:
-    return math_dialect.absi(val)
-
   def create_reduce(
       self, operands: Sequence[ir.Value], axis: int
   ) -> tt_dialect.ReduceOp:
@@ -898,20 +856,101 @@ class tensor(tl.core.tensor):
   def __eq__(self, other):
     return semantic.equal(self, _to_tensor(other))
 
-  __getitem__ = wrap_with_builder(tl.tensor.__getitem__)
+  def __getitem__(self, slices) -> tensor:
+    if slices is None or isinstance(slices, (slice, constexpr)):
+      slices = [slices]
+    t = self
+    for axis, s in enumerate(slices):
+      if s is None or (isinstance(s, constexpr) and s.value is None):
+        t = expand_dims(t, axis)
+      elif (
+          isinstance(s, slice)
+          and s.start is s.stop is s.step is None
+      ):
+        pass
+      else:
+        raise IndexError(f"unsupported tensor index: {s}")
+    return t
 
   to = wrap_with_builder(tl.tensor.to)
 
-program_id = wrap_with_builder(tl.core.program_id)
+def program_id(axis: int) -> tensor:
+  if axis not in range(3):
+    raise ValueError(f"axis must be in [0, 3), but got: {axis}")
+  return tensor(tt_dialect.get_program_id(axis), tl.int32)
 
 load = wrap_with_builder(tl.core.load)
 store = wrap_with_builder(tl.core.store)
-arange = wrap_with_builder(tl.core.arange)
-broadcast_to = wrap_with_builder(tl.core.broadcast_to)
-expand_dims = wrap_with_builder(tl.core.expand_dims)
-reshape = wrap_with_builder(tl.core.reshape)
+
+def arange(start: int, end: int) -> tensor:
+  if end <= start:
+    raise ValueError(
+        f"end must be greater than start, but got: {end} <= {start}"
+    )
+  if max(start, end) >= 2**31:
+    raise ValueError("start and end must fit in int32")
+  ty = block_type(tl.int32, [end - start])
+  ir_ty = ir.RankedTensorType.get(
+      [end - start], ir.IntegerType.get_signless(32)
+  )
+  return tensor(tt_dialect.make_range(ir_ty, start, end), ty)
+
+def broadcast_to(x: object, shape: Sequence[int | constexpr]) -> tensor:
+  x = _to_tensor(x)
+  if not x.type.is_block():
+    return splat(x, shape)
+  elif x.shape == shape:
+    return x
+  shape = [dim.__index__() for dim in shape]
+  x_ir_type = ir.RankedTensorType(x.handle.type)
+  result_ir_type = ir.RankedTensorType.get(
+      shape, x_ir_type.element_type, x_ir_type.encoding
+  )
+  return tensor(
+      tt_dialect.broadcast(result_ir_type, x.handle),
+      block_type(x.dtype, shape),
+  )
+
+def splat(x: object, shape: Sequence[int | constexpr]) -> tensor:
+  x = _to_tensor(x)
+  if x.type.is_block():
+    raise ValueError("cannot splat a block tensor")
+  if len(shape) == 0:
+    return x
+  shape = [dim.__index__() for dim in shape]
+  result_ir_type = ir.RankedTensorType.get(shape, x.handle.type)
+  return tensor(
+      tt_dialect.splat(result_ir_type, x.handle), block_type(x.dtype, shape)
+  )
+
+def expand_dims(x: object, axis: int) -> tensor:
+  x = _to_tensor(x)
+  dst_shape = [dim.__index__() for dim in x.shape]
+  dst_shape.insert(axis, 1)
+  if not x.type.is_block():
+    return splat(x, dst_shape)
+  return tensor(
+      tt_dialect.expand_dims(x.handle, axis),
+      block_type(x.dtype, dst_shape),
+  )
+
+def reshape(x: tensor, dst_shape: Sequence[int]) -> tensor:
+  x_ir_type = ir.RankedTensorType(x.handle.type)
+  result_ir_type = ir.RankedTensorType.get(
+      dst_shape, x_ir_type.element_type, x_ir_type.encoding
+  )
+  return tensor(
+      tt_dialect.reshape(result_ir_type, x.handle, allow_reorder=False),
+      block_type(x.dtype, dst_shape),
+  )
 
 dot = wrap_with_builder(tl.core.dot)
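The new `__getitem__` accepts only `None` (insert a length-1 axis) and full slices. A pure-Python model of just the shape bookkeeping; `index_shape` is a hypothetical helper written for this sketch, not part of the commit.

def index_shape(shape, slices):
  # `None` or a lone slice is normalized to a one-element list, as in
  # the method above.
  if slices is None or isinstance(slices, slice):
    slices = [slices]
  out = list(shape)
  for axis, s in enumerate(slices):
    if s is None:
      out.insert(axis, 1)  # like expand_dims(t, axis)
    elif isinstance(s, slice) and s.start is s.stop is s.step is None:
      pass  # a full `:` keeps the axis unchanged
    else:
      raise IndexError(f"unsupported tensor index: {s}")
  return out

assert index_shape([8, 128], (None, slice(None), None)) == [1, 8, 1, 128]
assert index_shape([8, 128], slice(None)) == [8, 128]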
@@ -924,15 +963,56 @@ atomic_or = wrap_with_builder(tl.core.atomic_or)
 atomic_xor = wrap_with_builder(tl.core.atomic_xor)
 atomic_cas = wrap_with_builder(tl.atomic_cas)
 
-abs = wrap_with_builder(tl.abs)
-exp = wrap_with_builder(tl.exp)
-log = wrap_with_builder(tl.log)
-sqrt = wrap_with_builder(tl.sqrt)
-sin = wrap_with_builder(tl.sin)
-cos = wrap_with_builder(tl.cos)
+def abs(x: object) -> tensor:
+  x = _to_tensor(x)
+  dtype = x.dtype
+  if dtype.is_floating():
+    return tensor(math_dialect.absf(x.handle), x.type)
+  elif dtype.is_int_signed():
+    return tensor(math_dialect.absi(x.handle), x.type)
+  elif dtype.is_int_unsigned():
+    return x
+  else:
+    raise ValueError(f"unsupported dtype: {dtype}")
 
-def multiple_of(x: tensor, values: list[int]) -> tl.tensor:
+def exp(x: object) -> tensor:
+  x = _to_tensor(x)
+  if x.dtype != float32 and x.dtype != float64:
+    raise ValueError(f"unsupported dtype: {x.dtype}")
+  return tensor(math_dialect.exp(x.handle), x.type)
+
+def log(x: object) -> tensor:
+  x = _to_tensor(x)
+  if x.dtype != float32 and x.dtype != float64:
+    raise ValueError(f"unsupported dtype: {x.dtype}")
+  return tensor(math_dialect.log(x.handle), x.type)
+
+def sqrt(x: object) -> tensor:
+  x = _to_tensor(x)
+  if x.dtype != float32 and x.dtype != float64:
+    raise ValueError(f"unsupported dtype: {x.dtype}")
+  return tensor(math_dialect.sqrt(x.handle), x.type)
+
+def sin(x: object) -> tensor:
+  x = _to_tensor(x)
+  if x.dtype != float32 and x.dtype != float64:
+    raise ValueError(f"unsupported dtype: {x.dtype}")
+  return tensor(math_dialect.sin(x.handle), x.type)
+
+def cos(x: object) -> tensor:
+  x = _to_tensor(x)
+  if x.dtype != float32 and x.dtype != float64:
+    raise ValueError(f"unsupported dtype: {x.dtype}")
+  return tensor(math_dialect.cos(x.handle), x.type)
+
+def multiple_of(x: tensor, values: Sequence[int]) -> tl.tensor:
   assert max(1, len(x.shape)) == len(values)
   set_attr(
       x.handle,
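The five float-only wrappers above repeat one dtype guard. A self-contained sketch of how that guard could be hoisted into a decorator; this factoring is mine, not the commit's, and `float32`/`float64` below are plain sentinel strings rather than triton dtypes.

import functools

float32, float64 = "float32", "float64"

def _float_only(fn):
  @functools.wraps(fn)
  def wrapper(x, dtype):
    # The same check exp/log/sqrt/sin/cos each spell out inline.
    if dtype != float32 and dtype != float64:
      raise ValueError(f"unsupported dtype: {dtype}")
    return fn(x, dtype)
  return wrapper

@_float_only
def exp(x, dtype):
  return ("math.exp", x)  # stands in for math_dialect.exp(x.handle)

assert exp(2.0, float32) == ("math.exp", 2.0)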
@@ -944,7 +1024,7 @@ def multiple_of(x: tensor, values: list[int]) -> tl.tensor:
   return x
 
-def max_contiguous(x: tensor, values: list[int]) -> tl.tensor:
+def max_contiguous(x: tensor, values: Sequence[int]) -> tl.tensor:
   assert len(x.shape) == len(values)
   set_attr(
       x.handle,
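`multiple_of` and `max_contiguous` only attach per-dimension compiler hints and return the tensor unchanged. A toy model of that contract: `FakeValue` and the dict-backed `set_attr` are stand-ins, and the attribute name `tt.divisibility` is assumed from Triton's convention rather than taken from this diff.

from typing import Sequence

class FakeValue:
  def __init__(self, shape: Sequence[int]):
    self.shape = list(shape)
    self.attrs: dict[str, list[int]] = {}

def set_attr(value: FakeValue, name: str, values: list[int]) -> None:
  value.attrs[name] = values  # the real code sets an MLIR attribute

def multiple_of(x: FakeValue, values: Sequence[int]) -> FakeValue:
  # Rank-0 tensors still take exactly one hint, hence max(1, ...).
  assert max(1, len(x.shape)) == len(values)
  set_attr(x, "tt.divisibility", list(values))
  return x

x = FakeValue([8, 128])
assert multiple_of(x, [16, 16]) is x
assert x.attrs["tt.divisibility"] == [16, 16]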
@@ -1010,7 +1090,6 @@ class semantic:
   ashr = wrap_with_builder(tl.semantic.ashr)
   cast = wrap_with_builder(tl.semantic.cast)
   equal = wrap_with_builder(tl.semantic.equal)
-  expand_dims = wrap_with_builder(tl.semantic.expand_dims)
   floordiv = wrap_with_builder(tl.semantic.floordiv)
   greater_equal = wrap_with_builder(tl.semantic.greater_equal)
   greater_than = wrap_with_builder(tl.semantic.greater_than)