Added ir.Value-based versions of load and store in triton.compat

PiperOrigin-RevId: 606597830
Sergei Lebedev 2024-02-13 06:12:54 -08:00, committed by jax authors
parent 21236f0c65
commit 6a7d1dceff
5 changed files with 106 additions and 69 deletions
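
The commit splits each memory op into two layers: the public tensor-based `load` and `store` keep their signatures, while new private `_load` and `_store` build the MLIR ops directly on `ir.Value` operands, so callers that already hold raw MLIR values can skip the `tensor` wrapper. A minimal sketch of the delegation pattern (the `tensor`, `_op`, and `op` definitions below are illustrative stand-ins, not the actual compat code):

import dataclasses
from typing import Any


@dataclasses.dataclass
class tensor:  # stand-in for the compat tensor wrapper
  handle: Any  # the underlying ir.Value
  type: Any    # the compat-level result type


def _op(x: Any) -> Any:
  # ir.Value -> ir.Value core: all MLIR op emission would happen here.
  return x


def op(x: tensor) -> tensor:
  # tensor -> tensor wrapper: unwrap the handle, re-wrap the result.
  return tensor(_op(x.handle), x.type)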


@@ -26,6 +26,9 @@ class PointerType(ir.Type):
   @property
   def pointee_type(self) -> ir.Type: ...
+  @property
+  def address_space(self) -> int: ...

 def infer_reduce_op_encoding(
     op_attribute: ir.Attribute,
     axis: int,
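
With the stub extended, Python code can inspect both halves of a Triton pointer type. A hedged usage sketch; the import paths and the `register_dialect` hook are assumptions, not APIs confirmed by this diff:

from jaxlib.mlir import ir                   # assumed import path
from jaxlib.triton import _triton_ext as tt  # assumed import path

with ir.Context() as ctx:
  tt.register_dialect(ctx)  # hypothetical registration hook
  f32 = ir.F32Type.get()
  ptr = tt.PointerType.get(f32, 1)  # (pointee type, address space)
  assert ptr.pointee_type == f32
  assert ptr.address_space == 1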


@@ -43,7 +43,7 @@ PYBIND11_MODULE(_triton_ext, m) {
   //
   mlir::python::adaptors::mlir_type_subclass(m, "PointerType",
-                                             mlirTritonIsAPointerType)
+                                             mlirTritonIsAPointer)
       .def_classmethod(
           "get",
           [](py::object cls, MlirType pointee_type, int64_t address_space) {
@@ -53,6 +53,9 @@ PYBIND11_MODULE(_triton_ext, m) {
           "Creates a PointerType type.")
       .def_property_readonly("pointee_type", [](MlirType self) {
         return mlirTritonPointerTypeGetPointeeType(self);
-      });
+      })
+      .def_property_readonly("address_space", [](MlirType self) {
+        return mlirTritonPointerTypeGetAddressSpace(self);
+      });
   //
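
Note the predicate rename: the subclass check is now backed by `mlirTritonIsAPointer`, matching the C library below. On the Python side, `mlir_type_subclass` generates the `isinstance` classmethod and checked-cast constructor that the compat layer relies on; roughly (a sketch, with the same assumed imports as above):

def describe(t: ir.Type) -> str:
  # PointerType.isinstance dispatches to mlirTritonIsAPointer;
  # PointerType(t) is the checked cast generated by mlir_type_subclass.
  if tt.PointerType.isinstance(t):
    p = tt.PointerType(t)
    return f"ptr<{p.pointee_type}, addrspace={p.address_space}>"
  return str(t)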


@@ -282,15 +282,15 @@ def _infer_load_return_type(ptr: ir.Value) -> ir.Type:
   return ptr_type.pointee_type


-def load(
-    ptr: tensor,
-    mask: tensor | None = None,
-    other: tensor | None = None,
+def _load(
+    ptr: ir.Value,
+    mask: ir.Value | None = None,
+    other: ir.Value | None = None,
     *,
     cache_modifier: str | None = None,
     eviction_policy: str | None = None,
     is_volatile: bool = False,
-) -> tensor:
+) -> ir.Value:
   if cache_modifier is None:
     cache_modifier = tt_dialect.CacheModifier.NONE
   elif cache_modifier == ".ca" or cache_modifier == ".cg":
@@ -307,62 +307,78 @@ def load(
         f"unsupported eviction policy: {eviction_policy}"
     ) from None
-  if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
-    # TODO(slebedev): Support load from a block pointer.
-    raise NotImplementedError("loading from a block pointer is not supported")
-  if not ptr.dtype.is_ptr():
-    raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
-  if other is not None:
-    if mask is None:
-      raise ValueError("other requires mask to be provided")
-    assert mask.shape == other.shape == ptr.shape, (
-        mask.shape,
-        other.shape,
-        ptr.shape,
-    )
-  elif mask is not None:
-    assert mask.shape == ptr.shape
-  if not ptr.type.is_block():
-    if other is not None and other.type.is_block():
+  if PointerType.isinstance(ptr.type):
+    ptr_type = PointerType(ptr.type)
+    if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
+      raise NotImplementedError("loading from a block pointer is not supported")
+  ptr_type = _element_type(ptr.type)
+  if not PointerType.isinstance(ptr_type):
+    raise ValueError(f"unsupported pointer type: {ptr_type}")
+  ptr_type = PointerType(ptr_type)
+  if other is not None and mask is None:
+    raise ValueError("other requires mask to be provided")
+  if not ir.RankedTensorType.isinstance(ptr.type):
+    if other is not None and ir.RankedTensorType.isinstance(other.type):
       raise ValueError("other cannot be a block if pointer is not a block")
-    if mask is not None and mask.type.is_block():
+    if mask is not None and ir.RankedTensorType.isinstance(mask.type):
       raise ValueError("mask cannot be a block if pointer is not a block")
-  ptr_type = ptr.dtype
-  element_type = ptr_type.element_ty
-  if element_type == int1:
-    # TODO(slebedev): Cast the result back to int1 before returning.
-    element_type = int8
-    ptr_type = pointer_type(element_type, ptr_type.address_space)
-    ptr = semantic.cast(ptr, ptr_type)
+  pointee_type = ptr_type.pointee_type
+  is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1
+  if is_int1:
+    pointee_type = ir.IntegerType.get_signless(8)
+    ptr = _cast(ptr, PointerType.get(pointee_type, ptr_type.address_space))
   if other is not None:
-    other = semantic.cast(other, element_type)
+    other = _cast(other, pointee_type)
-  result_handle = tt_dialect.load(
-      _infer_load_return_type(ptr.handle),
-      ptr.handle,
-      mask=mask.handle if mask is not None else None,
-      other=other.handle if other is not None else None,
+  result = tt_dialect.load(
+      _infer_load_return_type(ptr),
+      ptr,
+      mask=mask,
+      other=other,
       cache=cache_modifier,
       evict=eviction_policy,
       is_volatile=is_volatile,
   )
+  return (
+      result if not is_int1 else _cast(result, ir.IntegerType.get_signless(1))
+  )
+
+
+def load(
+    ptr: tensor,
+    mask: tensor | None = None,
+    other: tensor | None = None,
+    *,
+    cache_modifier: str | None = None,
+    eviction_policy: str | None = None,
+    is_volatile: bool = False,
+) -> tensor:
+  element_type = ptr.dtype.element_ty
+  result_handle = _load(
+      ptr.handle,
+      mask.handle if mask is not None else None,
+      other.handle if other is not None else None,
+      cache_modifier=cache_modifier,
+      eviction_policy=eviction_policy,
+      is_volatile=is_volatile,
+  )
   if ptr.type.is_block():
     return tensor(result_handle, block_type(element_type, ptr.type.shape))
   else:
     return tensor(result_handle, element_type)


-def store(
-    ptr: tensor,
-    value: tensor,
-    mask: tensor | None = None,
+def _store(
+    ptr: ir.Value,
+    value: ir.Value,
+    mask: ir.Value | None = None,
     *,
     cache_modifier: str | None = None,
     eviction_policy: str | None = None,
-) -> tensor:
+) -> ir.Value:
   if cache_modifier is None:
     cache_modifier = tt_dialect.CacheModifier.NONE
   elif cache_modifier != ".ca":
@@ -379,39 +395,47 @@ def store(
         f"unsupported eviction policy: {eviction_policy}"
     ) from None
-  if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
-    # TODO(slebedev): Support load from a block pointer.
-    raise NotImplementedError("storing to a block pointer is not supported")
+  if PointerType.isinstance(ptr.type):
+    ptr_type = PointerType(ptr.type)
+    if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
+      raise NotImplementedError("storing to a block pointer is not supported")
-  if not ptr.dtype.is_ptr():
-    raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
-  assert value.shape == ptr.shape
-  if mask is not None:
-    assert mask.shape == ptr.shape
-  if not ptr.type.is_block():
-    if value.type.is_block():
-      raise ValueError("other cannot be a block if pointer is not a block")
-    if mask is not None and mask.type.is_block():
+  ptr_type = _element_type(ptr.type)
+  if not PointerType.isinstance(ptr_type):
+    raise ValueError(f"unsupported pointer type: {ptr_type}")
+  ptr_type = PointerType(ptr_type)
+  if not ir.RankedTensorType.isinstance(ptr.type):
+    if ir.RankedTensorType.isinstance(value.type):
+      raise ValueError("value cannot be a block if pointer is not a block")
+    if mask is not None and ir.RankedTensorType.isinstance(mask.type):
       raise ValueError("mask cannot be a block if pointer is not a block")
-  ptr_type = ptr.dtype
-  element_type = ptr_type.element_ty
+  pointee_type = ptr_type.pointee_type
+  if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
+    pointee_type = ir.IntegerType.get_signless(8)
+    ptr = _cast(ptr, PointerType.get(pointee_type, ptr_type.address_space))
-  if element_type == int1:
-    # TODO(slebedev): Cast the result back to int1 before returning.
-    element_type = int8
-    ptr_type = pointer_type(element_type, ptr_type.address_space)
-    ptr = semantic.cast(ptr, ptr_type)
+  value = _cast(value, pointee_type)
+  return tt_dialect.store(
+      ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
+  )
-  value = semantic.cast(value, element_type)
+
+
+def store(
+    ptr: tensor,
+    value: tensor,
+    mask: tensor | None = None,
+    *,
+    cache_modifier: str | None = None,
+    eviction_policy: str | None = None,
+) -> tensor:
   return tensor(
-      tt_dialect.store(
+      _store(
           ptr.handle,
           value.handle,
-          mask=mask.handle if mask is not None else None,
-          cache=cache_modifier,
-          evict=eviction_policy,
+          mask.handle if mask is not None else None,
+          cache_modifier=cache_modifier,
+          eviction_policy=eviction_policy,
       ),
       void,
   )
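
The new validation code calls an `_element_type` helper that this diff does not show. From its call sites it evidently unwraps a ranked tensor type to its element type and returns any other type unchanged, so scalar pointers and tensors of pointers go through the same checks. A sketch under that assumption:

def _element_type(t: ir.Type) -> ir.Type:
  # Presumed behavior, inferred from the call sites above: a tensor of
  # pointers is classified by its element type; anything else by itself.
  if ir.RankedTensorType.isinstance(t):
    return ir.RankedTensorType(t).element_type
  return t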


@@ -34,7 +34,7 @@ MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace) {
       mlir::triton::PointerType::get(unwrap(pointeeType), addressSpace));
 }

-bool mlirTritonIsAPointerType(MlirType type) {
+bool mlirTritonIsAPointer(MlirType type) {
   return llvm::isa<mlir::triton::PointerType>(unwrap(type));
 }

@@ -43,6 +43,11 @@ MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType) {
       .getPointeeType());
 }

+int mlirTritonPointerTypeGetAddressSpace(MlirType pointerType) {
+  return llvm::cast<mlir::triton::PointerType>(unwrap(pointerType))
+      .getAddressSpace();
+}
+
 MlirAttribute mlirTritonInferReduceOpEncoding(MlirAttribute operandEncoding,
                                               int axis) {
   auto opEncoding = unwrap(operandEncoding);
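
The new getter mirrors `PointerType::getAddressSpace` and feeds the `address_space` property bound above. That property is what lets `_load` and `_store` rebuild a pointer type in the same address space when they widen i1 pointees to i8 (the `is_int1` path in the compat diff). A hedged sketch of that pattern, with the same assumed imports as earlier:

def widen_i1_pointer(ptr_type):
  # Rebuild ptr<i8, AS> from ptr<i1, AS>, preserving the address space.
  pointee = ptr_type.pointee_type
  if isinstance(pointee, ir.IntegerType) and pointee.width == 1:
    i8 = ir.IntegerType.get_signless(8)
    return tt.PointerType.get(i8, ptr_type.address_space)
  return ptr_type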


@@ -27,9 +27,11 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Triton, triton);
 MLIR_CAPI_EXPORTED MlirType mlirTritonPointerTypeGet(MlirType pointeeType,
                                                      int addressSpace);

-MLIR_CAPI_EXPORTED bool mlirTritonIsAPointerType(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirTritonIsAPointer(MlirType type);

 MLIR_CAPI_EXPORTED MlirType
 mlirTritonPointerTypeGetPointeeType(MlirType pointerType);

+MLIR_CAPI_EXPORTED int
+mlirTritonPointerTypeGetAddressSpace(MlirType pointerType);
+
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirTritonInferReduceOpEncoding(MlirAttribute operandEncoding, int axis);