mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Added ir.Value-based versions of load and store in triton.compat
PiperOrigin-RevId: 606597830
This commit is contained in:
parent
21236f0c65
commit
6a7d1dceff
@ -26,6 +26,9 @@ class PointerType(ir.Type):
|
||||
@property
|
||||
def pointee_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def address_space(self) -> int: ...
|
||||
|
||||
def infer_reduce_op_encoding(
|
||||
op_attribute: ir.Attribute,
|
||||
axis: int,
|
||||
|
@ -43,7 +43,7 @@ PYBIND11_MODULE(_triton_ext, m) {
|
||||
//
|
||||
|
||||
mlir::python::adaptors::mlir_type_subclass(m, "PointerType",
|
||||
mlirTritonIsAPointerType)
|
||||
mlirTritonIsAPointer)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](py::object cls, MlirType pointee_type, int64_t address_space) {
|
||||
@ -53,6 +53,9 @@ PYBIND11_MODULE(_triton_ext, m) {
|
||||
"Creates a PointerType type.")
|
||||
.def_property_readonly("pointee_type", [](MlirType self) {
|
||||
return mlirTritonPointerTypeGetPointeeType(self);
|
||||
})
|
||||
.def_property_readonly("address_space", [](MlirType self) {
|
||||
return mlirTritonPointerTypeGetAddressSpace(self);
|
||||
});
|
||||
|
||||
//
|
||||
|
@ -282,15 +282,15 @@ def _infer_load_return_type(ptr: ir.Value) -> ir.Type:
|
||||
return ptr_type.pointee_type
|
||||
|
||||
|
||||
def load(
|
||||
ptr: tensor,
|
||||
mask: tensor | None = None,
|
||||
other: tensor | None = None,
|
||||
def _load(
|
||||
ptr: ir.Value,
|
||||
mask: ir.Value | None = None,
|
||||
other: ir.Value | None = None,
|
||||
*,
|
||||
cache_modifier: str | None = None,
|
||||
eviction_policy: str | None = None,
|
||||
is_volatile: bool = False,
|
||||
) -> tensor:
|
||||
) -> ir.Value:
|
||||
if cache_modifier is None:
|
||||
cache_modifier = tt_dialect.CacheModifier.NONE
|
||||
elif cache_modifier == ".ca" or cache_modifier == ".cg":
|
||||
@ -307,62 +307,78 @@ def load(
|
||||
f"unsupported eviction policy: {eviction_policy}"
|
||||
) from None
|
||||
|
||||
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
||||
# TODO(slebedev): Support load from a block pointer.
|
||||
raise NotImplementedError("loading from a block pointer is not supported")
|
||||
if not ptr.dtype.is_ptr():
|
||||
raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
|
||||
if other is not None:
|
||||
if mask is None:
|
||||
raise ValueError("other requires mask to be provided")
|
||||
assert mask.shape == other.shape == ptr.shape, (
|
||||
mask.shape,
|
||||
other.shape,
|
||||
ptr.shape,
|
||||
)
|
||||
elif mask is not None:
|
||||
assert mask.shape == ptr.shape
|
||||
if not ptr.type.is_block():
|
||||
if other is not None and other.type.is_block():
|
||||
if PointerType.isinstance(ptr.type):
|
||||
ptr_type = PointerType(ptr.type)
|
||||
if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
|
||||
raise NotImplementedError("loading from a block pointer is not supported")
|
||||
|
||||
ptr_type = _element_type(ptr.type)
|
||||
if not PointerType.isinstance(ptr_type):
|
||||
raise ValueError(f"unsupported pointer type: {ptr_type}")
|
||||
ptr_type = PointerType(ptr_type)
|
||||
if other is not None and mask is None:
|
||||
raise ValueError("other requires mask to be provided")
|
||||
if not ir.RankedTensorType.isinstance(ptr.type):
|
||||
if other is not None and ir.RankedTensorType.isinstance(other.type):
|
||||
raise ValueError("other cannot be a block if pointer is not a block")
|
||||
if mask is not None and mask.type.is_block():
|
||||
if mask is not None and ir.RankedTensorType.isinstance(mask.type):
|
||||
raise ValueError("mask cannot be a block if pointer is not a block")
|
||||
|
||||
ptr_type = ptr.dtype
|
||||
element_type = ptr_type.element_ty
|
||||
|
||||
if element_type == int1:
|
||||
# TODO(slebedev): Cast the result back to int1 before returning.
|
||||
element_type = int8
|
||||
ptr_type = pointer_type(element_type, ptr_type.address_space)
|
||||
ptr = semantic.cast(ptr, ptr_type)
|
||||
pointee_type = ptr_type.pointee_type
|
||||
is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1
|
||||
if is_int1:
|
||||
pointee_type = ir.IntegerType.get_signless(8)
|
||||
ptr = _cast(ptr, PointerType.get(pointee_type, ptr_type.address_space))
|
||||
|
||||
if other is not None:
|
||||
other = semantic.cast(other, element_type)
|
||||
other = _cast(other, pointee_type)
|
||||
|
||||
result_handle = tt_dialect.load(
|
||||
_infer_load_return_type(ptr.handle),
|
||||
ptr.handle,
|
||||
mask=mask.handle if mask is not None else None,
|
||||
other=other.handle if other is not None else None,
|
||||
result = tt_dialect.load(
|
||||
_infer_load_return_type(ptr),
|
||||
ptr,
|
||||
mask=mask,
|
||||
other=other,
|
||||
cache=cache_modifier,
|
||||
evict=eviction_policy,
|
||||
is_volatile=is_volatile,
|
||||
)
|
||||
return (
|
||||
result if not is_int1 else _cast(result, ir.IntegerType.get_signless(1))
|
||||
)
|
||||
|
||||
|
||||
def load(
|
||||
ptr: tensor,
|
||||
mask: tensor | None = None,
|
||||
other: tensor | None = None,
|
||||
*,
|
||||
cache_modifier: str | None = None,
|
||||
eviction_policy: str | None = None,
|
||||
is_volatile: bool = False,
|
||||
) -> tensor:
|
||||
element_type = ptr.dtype.element_ty
|
||||
result_handle = _load(
|
||||
ptr.handle,
|
||||
mask.handle if mask is not None else None,
|
||||
other.handle if other is not None else None,
|
||||
cache_modifier=cache_modifier,
|
||||
eviction_policy=eviction_policy,
|
||||
is_volatile=is_volatile,
|
||||
)
|
||||
if ptr.type.is_block():
|
||||
return tensor(result_handle, block_type(element_type, ptr.type.shape))
|
||||
else:
|
||||
return tensor(result_handle, element_type)
|
||||
|
||||
|
||||
def store(
|
||||
ptr: tensor,
|
||||
value: tensor,
|
||||
mask: tensor | None = None,
|
||||
def _store(
|
||||
ptr: ir.Value,
|
||||
value: ir.Value,
|
||||
mask: ir.Value | None = None,
|
||||
*,
|
||||
cache_modifier: str | None = None,
|
||||
eviction_policy: str | None = None,
|
||||
) -> tensor:
|
||||
) -> ir.Value:
|
||||
if cache_modifier is None:
|
||||
cache_modifier = tt_dialect.CacheModifier.NONE
|
||||
elif cache_modifier != ".ca":
|
||||
@ -379,39 +395,47 @@ def store(
|
||||
f"unsupported eviction policy: {eviction_policy}"
|
||||
) from None
|
||||
|
||||
if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
|
||||
# TODO(slebedev): Support load from a block pointer.
|
||||
raise NotImplementedError("storing to a block pointer is not supported")
|
||||
if PointerType.isinstance(ptr.type):
|
||||
ptr_type = PointerType(ptr.type)
|
||||
if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
|
||||
raise NotImplementedError("loading from a block pointer is not supported")
|
||||
|
||||
if not ptr.dtype.is_ptr():
|
||||
raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
|
||||
assert value.shape == ptr.shape
|
||||
if mask is not None:
|
||||
assert mask.shape == ptr.shape
|
||||
if not ptr.type.is_block():
|
||||
if value.type.is_block():
|
||||
raise ValueError("other cannot be a block if pointer is not a block")
|
||||
if mask is not None and mask.type.is_block():
|
||||
ptr_type = _element_type(ptr.type)
|
||||
if not PointerType.isinstance(ptr_type):
|
||||
raise ValueError(f"unsupported pointer type: {ptr_type}")
|
||||
ptr_type = PointerType(ptr_type)
|
||||
if not ir.RankedTensorType.isinstance(ptr.type):
|
||||
if ir.RankedTensorType.isinstance(value.type):
|
||||
raise ValueError("value cannot be a block if pointer is not a block")
|
||||
if mask is not None and ir.RankedTensorType.isinstance(mask.type):
|
||||
raise ValueError("mask cannot be a block if pointer is not a block")
|
||||
|
||||
ptr_type = ptr.dtype
|
||||
element_type = ptr_type.element_ty
|
||||
pointee_type = ptr_type.pointee_type
|
||||
if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
|
||||
pointee_type = ir.IntegerType.get_signless(8)
|
||||
ptr = _cast(ptr, PointerType.get(pointee_type, ptr_type.address_space))
|
||||
|
||||
if element_type == int1:
|
||||
# TODO(slebedev): Cast the result back to int1 before returning.
|
||||
element_type = int8
|
||||
ptr_type = pointer_type(element_type, ptr_type.address_space)
|
||||
ptr = semantic.cast(ptr, ptr_type)
|
||||
value = _cast(value, pointee_type)
|
||||
return tt_dialect.store(
|
||||
ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
|
||||
)
|
||||
|
||||
value = semantic.cast(value, element_type)
|
||||
|
||||
def store(
|
||||
ptr: tensor,
|
||||
value: tensor,
|
||||
mask: tensor | None = None,
|
||||
*,
|
||||
cache_modifier: str | None = None,
|
||||
eviction_policy: str | None = None,
|
||||
) -> tensor:
|
||||
return tensor(
|
||||
tt_dialect.store(
|
||||
_store(
|
||||
ptr.handle,
|
||||
value.handle,
|
||||
mask=mask.handle if mask is not None else None,
|
||||
cache=cache_modifier,
|
||||
evict=eviction_policy,
|
||||
mask.handle if mask is not None else None,
|
||||
cache_modifier=cache_modifier,
|
||||
eviction_policy=eviction_policy,
|
||||
),
|
||||
void,
|
||||
)
|
||||
|
@ -34,7 +34,7 @@ MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace) {
|
||||
mlir::triton::PointerType::get(unwrap(pointeeType), addressSpace));
|
||||
}
|
||||
|
||||
bool mlirTritonIsAPointerType(MlirType type) {
|
||||
bool mlirTritonIsAPointer(MlirType type) {
|
||||
return llvm::isa<mlir::triton::PointerType>(unwrap(type));
|
||||
}
|
||||
|
||||
@ -43,6 +43,11 @@ MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType) {
|
||||
.getPointeeType());
|
||||
}
|
||||
|
||||
int mlirTritonPointerTypeGetAddressSpace(MlirType pointerType) {
|
||||
return llvm::cast<mlir::triton::PointerType>(unwrap(pointerType))
|
||||
.getAddressSpace();
|
||||
}
|
||||
|
||||
MlirAttribute mlirTritonInferReduceOpEncoding(MlirAttribute operandEncoding,
|
||||
int axis) {
|
||||
auto opEncoding = unwrap(operandEncoding);
|
||||
|
@ -27,9 +27,11 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Triton, triton);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirTritonPointerTypeGet(MlirType pointeeType,
|
||||
int addressSpace);
|
||||
MLIR_CAPI_EXPORTED bool mlirTritonIsAPointerType(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool mlirTritonIsAPointer(MlirType type);
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
mlirTritonPointerTypeGetPointeeType(MlirType pointerType);
|
||||
MLIR_CAPI_EXPORTED int
|
||||
mlirTritonPointerTypeGetAddressSpace(MlirType pointerType);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
mlirTritonInferReduceOpEncoding(MlirAttribute operandEncoding, int axis);
|
||||
|
Loading…
x
Reference in New Issue
Block a user