mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Migrated a subset of triton.compat to directly use IR builders
PiperOrigin-RevId: 598826331
This commit is contained in:
parent
ab8eb896d7
commit
af49b01e1f
@ -253,10 +253,6 @@ class builder:
|
||||
) -> scf_dialect.ConditionOp:
|
||||
return scf_dialect.ConditionOp(cond, args)
|
||||
|
||||
def create_make_range(self, start: int, end: int) -> scf_dialect.MakeRangeOp:
|
||||
result = ir.RankedTensorType.get([end - start], self.get_int32_ty())
|
||||
return tt_dialect.make_range(result, start, end)
|
||||
|
||||
def create_fp_to_fp(self, src: ir.Value, dst_type: ir.Type) -> ir.Value:
|
||||
return tt_dialect.fp_to_fp(dst_type, src)
|
||||
|
||||
@ -571,19 +567,6 @@ class builder:
|
||||
evict=eviction_policy,
|
||||
)
|
||||
|
||||
def create_reshape(
|
||||
self, arg: ir.Value, shape: Sequence[int], allow_reorder: bool
|
||||
) -> ir.Value:
|
||||
assert ir.RankedTensorType.isinstance(arg.type)
|
||||
arg_type = ir.RankedTensorType(arg.type)
|
||||
result_type = ir.RankedTensorType.get(
|
||||
shape, arg_type.element_type, arg_type.encoding
|
||||
)
|
||||
return tt_dialect.reshape(result_type, arg, allow_reorder)
|
||||
|
||||
def create_expand_dims(self, arg: ir.Value, axis: int) -> ir.Value:
|
||||
return tt_dialect.expand_dims(arg, axis)
|
||||
|
||||
def create_cat(self, lhs: ir.Value, rhs: ir.Value) -> ir.Value:
|
||||
assert ir.RankedTensorType.isinstance(lhs.type)
|
||||
assert ir.RankedTensorType.isinstance(rhs.type)
|
||||
@ -667,10 +650,6 @@ class builder:
|
||||
return_type, args, lib_name, lib_path, symbol, is_pure
|
||||
)
|
||||
|
||||
def create_get_program_id(self, axis: int) -> ir.Value:
|
||||
assert 0 <= axis < 3
|
||||
return tt_dialect.get_program_id(axis)
|
||||
|
||||
def create_get_num_programs(self, axis: int) -> ir.Value:
|
||||
return tt_dialect.get_num_programs(axis)
|
||||
|
||||
@ -684,27 +663,6 @@ class builder:
|
||||
) -> ir.Value:
|
||||
return tt_dialect.dot(a, b, c, allow_tf32, max_num_imprecise_acc)
|
||||
|
||||
def create_exp(self, val: ir.Value) -> ir.Value:
|
||||
return math_dialect.exp(val)
|
||||
|
||||
def create_cos(self, val: ir.Value) -> ir.Value:
|
||||
return math_dialect.cos(val)
|
||||
|
||||
def create_sin(self, val: ir.Value) -> ir.Value:
|
||||
return math_dialect.sin(val)
|
||||
|
||||
def create_log(self, val: ir.Value) -> ir.Value:
|
||||
return math_dialect.log(val)
|
||||
|
||||
def create_sqrt(self, val: ir.Value) -> ir.Value:
|
||||
return math_dialect.sqrt(val)
|
||||
|
||||
def create_fabs(self, val: ir.Value) -> ir.Value:
|
||||
return math_dialect.absf(val)
|
||||
|
||||
def create_iabs(self, val: ir.Value) -> ir.Value:
|
||||
return math_dialect.absi(val)
|
||||
|
||||
def create_reduce(
|
||||
self, operands: Sequence[ir.Value], axis: int
|
||||
) -> tt_dialect.ReduceOp:
|
||||
@ -898,20 +856,101 @@ class tensor(tl.core.tensor):
|
||||
def __eq__(self, other):
|
||||
return semantic.equal(self, _to_tensor(other))
|
||||
|
||||
__getitem__ = wrap_with_builder(tl.tensor.__getitem__)
|
||||
def __getitem__(self, slices) -> tensor:
|
||||
if isinstance(slices, (slice, constexpr)):
|
||||
slices = [slices]
|
||||
t = self
|
||||
for axis, s in enumerate(slices):
|
||||
if s is None or isinstance(s, constexpr) and s.value is None:
|
||||
t = expand_dims(t, axis)
|
||||
elif (
|
||||
isinstance(s, slice)
|
||||
and s.start is s.stop is s.step is None
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise IndexError(f"unsupported tensor index: {s}")
|
||||
return t
|
||||
|
||||
to = wrap_with_builder(tl.tensor.to)
|
||||
|
||||
|
||||
program_id = wrap_with_builder(tl.core.program_id)
|
||||
def program_id(axis: int) -> tensor:
|
||||
if axis not in range(3):
|
||||
raise ValueError(f"axis must be in [0, 3), but got: {axis}")
|
||||
return tensor(tt_dialect.get_program_id(axis), tl.int32)
|
||||
|
||||
|
||||
load = wrap_with_builder(tl.core.load)
|
||||
store = wrap_with_builder(tl.core.store)
|
||||
|
||||
arange = wrap_with_builder(tl.core.arange)
|
||||
broadcast_to = wrap_with_builder(tl.core.broadcast_to)
|
||||
expand_dims = wrap_with_builder(tl.core.expand_dims)
|
||||
reshape = wrap_with_builder(tl.core.reshape)
|
||||
|
||||
def arange(start: int, end: int) -> tensor:
|
||||
if end <= start:
|
||||
raise ValueError(
|
||||
f"end must be greater than start, but got: {end} <= {start}"
|
||||
)
|
||||
if max(start, end) >= 2**32:
|
||||
raise ValueError("start and end must fit in int32")
|
||||
ty = block_type(tl.int32, [end - start])
|
||||
ir_ty = ir.RankedTensorType.get(
|
||||
[end - start], ir.IntegerType.get_signless(32)
|
||||
)
|
||||
return tensor(tt_dialect.make_range(ir_ty, start, end), ty)
|
||||
|
||||
|
||||
def broadcast_to(x: object, shape: Sequence[int | constexpr]) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
if not x.type.is_block():
|
||||
return splat(x, shape)
|
||||
elif x.shape == shape:
|
||||
return x
|
||||
shape = [dim.__index__() for dim in shape]
|
||||
x_ir_type = ir.RankedTensorType(x.handle.type)
|
||||
result_ir_type = ir.RankedTensorType.get(
|
||||
shape, x_ir_type.element_type, x_ir_type.encoding
|
||||
)
|
||||
return tensor(
|
||||
tt_dialect.broadcast(result_ir_type, x.handle),
|
||||
block_type(x.dtype, shape),
|
||||
)
|
||||
|
||||
|
||||
def splat(x: object, shape: Sequence[int | constexpr]) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
if x.type.is_block():
|
||||
raise ValueError("cannot splat a block tensor")
|
||||
if len(shape) == 0:
|
||||
return x
|
||||
shape = [dim.__index__() for dim in shape]
|
||||
result_ir_type = ir.RankedTensorType.get(shape, x.handle.type)
|
||||
return tensor(
|
||||
tt_dialect.splat(result_ir_type, x.handle), block_type(x.dtype, shape)
|
||||
)
|
||||
|
||||
|
||||
def expand_dims(x: object, axis: int) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
dst_shape = [dim.__index__() for dim in x.shape]
|
||||
dst_shape.insert(axis, 1)
|
||||
if not x.type.is_block():
|
||||
return splat(input, dst_shape)
|
||||
return tensor(
|
||||
tt_dialect.expand_dims(x.handle, axis),
|
||||
block_type(x.dtype, dst_shape),
|
||||
)
|
||||
|
||||
|
||||
def reshape(x: tensor, dst_shape: Sequence[int]) -> tensor:
|
||||
x_ir_type = ir.RankedTensorType(x.handle.type)
|
||||
result_ir_type = ir.RankedTensorType.get(
|
||||
dst_shape, x_ir_type.element_type, x_ir_type.encoding
|
||||
)
|
||||
return tensor(
|
||||
tt_dialect.reshape(result_ir_type, x.handle, allow_reorder=False),
|
||||
block_type(x.dtype, dst_shape),
|
||||
)
|
||||
|
||||
|
||||
dot = wrap_with_builder(tl.core.dot)
|
||||
|
||||
@ -924,15 +963,56 @@ atomic_or = wrap_with_builder(tl.core.atomic_or)
|
||||
atomic_xor = wrap_with_builder(tl.core.atomic_xor)
|
||||
atomic_cas = wrap_with_builder(tl.atomic_cas)
|
||||
|
||||
abs = wrap_with_builder(tl.abs)
|
||||
exp = wrap_with_builder(tl.exp)
|
||||
log = wrap_with_builder(tl.log)
|
||||
sqrt = wrap_with_builder(tl.sqrt)
|
||||
sin = wrap_with_builder(tl.sin)
|
||||
cos = wrap_with_builder(tl.cos)
|
||||
|
||||
def abs(x: object) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
dtype = x.dtype
|
||||
if dtype.is_floating():
|
||||
return tensor(math_dialect.absf(x.handle), x.type)
|
||||
elif dtype.is_int_signed():
|
||||
return tensor(math_dialect.absi(x.handle), x.type)
|
||||
elif dtype.is_int_unsigned():
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
def multiple_of(x: tensor, values: list[int]) -> tl.tensor:
|
||||
def exp(x: object) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
if x.dtype != float32 and x.dtype != float64:
|
||||
raise ValueError(f"unsupported dtype: {x.dtype}")
|
||||
return tensor(math_dialect.exp(x.handle), x.type)
|
||||
|
||||
|
||||
def log(x: object) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
if x.dtype != float32 and x.dtype != float64:
|
||||
raise ValueError(f"unsupported dtype: {x.dtype}")
|
||||
return tensor(math_dialect.log(x.handle), x.type)
|
||||
|
||||
|
||||
def sqrt(x: object) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
if x.dtype != float32 and x.dtype != float64:
|
||||
raise ValueError(f"unsupported dtype: {x.dtype}")
|
||||
return tensor(math_dialect.sqrt(x.handle), x.type)
|
||||
|
||||
|
||||
def sin(x: object) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
if x.dtype != float32 and x.dtype != float64:
|
||||
raise ValueError(f"unsupported dtype: {x.dtype}")
|
||||
return tensor(math_dialect.sin(x.handle), x.type)
|
||||
|
||||
|
||||
def cos(x: object) -> tensor:
|
||||
x = _to_tensor(x)
|
||||
if x.dtype != float32 and x.dtype != float64:
|
||||
raise ValueError(f"unsupported dtype: {x.dtype}")
|
||||
return tensor(math_dialect.cos(x.handle), x.type)
|
||||
|
||||
|
||||
def multiple_of(x: tensor, values: Sequence[int]) -> tl.tensor:
|
||||
assert max(1, len(x.shape)) == len(values)
|
||||
set_attr(
|
||||
x.handle,
|
||||
@ -944,7 +1024,7 @@ def multiple_of(x: tensor, values: list[int]) -> tl.tensor:
|
||||
return x
|
||||
|
||||
|
||||
def max_contiguous(x: tensor, values: list[int]) -> tl.tensor:
|
||||
def max_contiguous(x: tensor, values: Sequence[int]) -> tl.tensor:
|
||||
assert len(x.shape) == len(values)
|
||||
set_attr(
|
||||
x.handle,
|
||||
@ -1010,7 +1090,6 @@ class semantic:
|
||||
ashr = wrap_with_builder(tl.semantic.ashr)
|
||||
cast = wrap_with_builder(tl.semantic.cast)
|
||||
equal = wrap_with_builder(tl.semantic.equal)
|
||||
expand_dims = wrap_with_builder(tl.semantic.expand_dims)
|
||||
floordiv = wrap_with_builder(tl.semantic.floordiv)
|
||||
greater_equal = wrap_with_builder(tl.semantic.greater_equal)
|
||||
greater_than = wrap_with_builder(tl.semantic.greater_than)
|
||||
|
Loading…
x
Reference in New Issue
Block a user