diff --git a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi
index 96e2e5a14..1c7a3072f 100644
--- a/jaxlib/mlir/_mlir_libs/_triton_ext.pyi
+++ b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi
@@ -26,6 +26,9 @@ class PointerType(ir.Type):
   @property
   def pointee_type(self) -> ir.Type: ...
 
+  @property
+  def address_space(self) -> int: ...
+
 def infer_reduce_op_encoding(
     op_attribute: ir.Attribute,
     axis: int,
diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc
index e7a6e3b6c..4b900b7c1 100644
--- a/jaxlib/mlir/_mlir_libs/triton_ext.cc
+++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc
@@ -43,7 +43,7 @@ PYBIND11_MODULE(_triton_ext, m) {
 
   //
   mlir::python::adaptors::mlir_type_subclass(m, "PointerType",
-                                             mlirTritonIsAPointerType)
+                                             mlirTritonIsAPointer)
       .def_classmethod(
           "get",
          [](py::object cls, MlirType pointee_type, int64_t address_space) {
@@ -53,6 +53,9 @@ PYBIND11_MODULE(_triton_ext, m) {
           "Creates a PointerType type.")
       .def_property_readonly("pointee_type", [](MlirType self) {
         return mlirTritonPointerTypeGetPointeeType(self);
+      })
+      .def_property_readonly("address_space", [](MlirType self) {
+        return mlirTritonPointerTypeGetAddressSpace(self);
       });
 
   //
diff --git a/jaxlib/triton/compat.py b/jaxlib/triton/compat.py
index 65e7b0852..b08f1b942 100644
--- a/jaxlib/triton/compat.py
+++ b/jaxlib/triton/compat.py
@@ -282,15 +282,15 @@ def _infer_load_return_type(ptr: ir.Value) -> ir.Type:
     return ptr_type.pointee_type
 
 
-def load(
-    ptr: tensor,
-    mask: tensor | None = None,
-    other: tensor | None = None,
+def _load(
+    ptr: ir.Value,
+    mask: ir.Value | None = None,
+    other: ir.Value | None = None,
     *,
     cache_modifier: str | None = None,
     eviction_policy: str | None = None,
     is_volatile: bool = False,
-) -> tensor:
+) -> ir.Value:
   if cache_modifier is None:
     cache_modifier = tt_dialect.CacheModifier.NONE
   elif cache_modifier == ".ca" or cache_modifier == ".cg":
@@ -307,62 +307,78 @@ def load(
         f"unsupported eviction policy: {eviction_policy}"
     ) from None
 
-  if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
-    # TODO(slebedev): Support load from a block pointer.
-    raise NotImplementedError("loading from a block pointer is not supported")
-  if not ptr.dtype.is_ptr():
-    raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
-  if other is not None:
-    if mask is None:
-      raise ValueError("other requires mask to be provided")
-    assert mask.shape == other.shape == ptr.shape, (
-        mask.shape,
-        other.shape,
-        ptr.shape,
-    )
-  elif mask is not None:
-    assert mask.shape == ptr.shape
-  if not ptr.type.is_block():
-    if other is not None and other.type.is_block():
+  if PointerType.isinstance(ptr.type):
+    ptr_type = PointerType(ptr.type)
+    if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
+      raise NotImplementedError("loading from a block pointer is not supported")
+
+  ptr_type = _element_type(ptr.type)
+  if not PointerType.isinstance(ptr_type):
+    raise ValueError(f"unsupported pointer type: {ptr_type}")
+  ptr_type = PointerType(ptr_type)
+  if other is not None and mask is None:
+    raise ValueError("other requires mask to be provided")
+  if not ir.RankedTensorType.isinstance(ptr.type):
+    if other is not None and ir.RankedTensorType.isinstance(other.type):
       raise ValueError("other cannot be a block if pointer is not a block")
-    if mask is not None and mask.type.is_block():
+    if mask is not None and ir.RankedTensorType.isinstance(mask.type):
       raise ValueError("mask cannot be a block if pointer is not a block")
 
-  ptr_type = ptr.dtype
-  element_type = ptr_type.element_ty
-
-  if element_type == int1:
-    # TODO(slebedev): Cast the result back to int1 before returning.
-    element_type = int8
-    ptr_type = pointer_type(element_type, ptr_type.address_space)
-    ptr = semantic.cast(ptr, ptr_type)
+  pointee_type = ptr_type.pointee_type
+  is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1
+  if is_int1:
+    pointee_type = ir.IntegerType.get_signless(8)
+    ptr = _cast(ptr, PointerType.get(pointee_type, ptr_type.address_space))
 
   if other is not None:
-    other = semantic.cast(other, element_type)
+    other = _cast(other, pointee_type)
 
-  result_handle = tt_dialect.load(
-      _infer_load_return_type(ptr.handle),
-      ptr.handle,
-      mask=mask.handle if mask is not None else None,
-      other=other.handle if other is not None else None,
+  result = tt_dialect.load(
+      _infer_load_return_type(ptr),
+      ptr,
+      mask=mask,
+      other=other,
       cache=cache_modifier,
       evict=eviction_policy,
       is_volatile=is_volatile,
   )
+  return (
+      result if not is_int1 else _cast(result, ir.IntegerType.get_signless(1))
+  )
+
+
+def load(
+    ptr: tensor,
+    mask: tensor | None = None,
+    other: tensor | None = None,
+    *,
+    cache_modifier: str | None = None,
+    eviction_policy: str | None = None,
+    is_volatile: bool = False,
+) -> tensor:
+  element_type = ptr.dtype.element_ty
+  result_handle = _load(
+      ptr.handle,
+      mask.handle if mask is not None else None,
+      other.handle if other is not None else None,
+      cache_modifier=cache_modifier,
+      eviction_policy=eviction_policy,
+      is_volatile=is_volatile,
+  )
   if ptr.type.is_block():
     return tensor(result_handle, block_type(element_type, ptr.type.shape))
   else:
     return tensor(result_handle, element_type)
 
 
-def store(
-    ptr: tensor,
-    value: tensor,
-    mask: tensor | None = None,
+def _store(
+    ptr: ir.Value,
+    value: ir.Value,
+    mask: ir.Value | None = None,
     *,
     cache_modifier: str | None = None,
     eviction_policy: str | None = None,
-) -> tensor:
+) -> ir.Value:
   if cache_modifier is None:
     cache_modifier = tt_dialect.CacheModifier.NONE
   elif cache_modifier != ".ca":
@@ -379,39 +395,47 @@ def store(
         f"unsupported eviction policy: {eviction_policy}"
     ) from None
 
-  if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
-    # TODO(slebedev): Support load from a block pointer.
-    raise NotImplementedError("storing to a block pointer is not supported")
+  if PointerType.isinstance(ptr.type):
+    ptr_type = PointerType(ptr.type)
+    if ir.RankedTensorType.isinstance(ptr_type.pointee_type):
+      raise NotImplementedError("storing to a block pointer is not supported")
 
-  if not ptr.dtype.is_ptr():
-    raise ValueError(f"unsupported pointer dtype: {ptr.dtype}")
-  assert value.shape == ptr.shape
-  if mask is not None:
-    assert mask.shape == ptr.shape
-  if not ptr.type.is_block():
-    if value.type.is_block():
-      raise ValueError("other cannot be a block if pointer is not a block")
-    if mask is not None and mask.type.is_block():
+  ptr_type = _element_type(ptr.type)
+  if not PointerType.isinstance(ptr_type):
+    raise ValueError(f"unsupported pointer type: {ptr_type}")
+  ptr_type = PointerType(ptr_type)
+  if not ir.RankedTensorType.isinstance(ptr.type):
+    if ir.RankedTensorType.isinstance(value.type):
+      raise ValueError("value cannot be a block if pointer is not a block")
+    if mask is not None and ir.RankedTensorType.isinstance(mask.type):
       raise ValueError("mask cannot be a block if pointer is not a block")
 
-  ptr_type = ptr.dtype
-  element_type = ptr_type.element_ty
+  pointee_type = ptr_type.pointee_type
+  if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
+    pointee_type = ir.IntegerType.get_signless(8)
+    ptr = _cast(ptr, PointerType.get(pointee_type, ptr_type.address_space))
 
-  if element_type == int1:
-    # TODO(slebedev): Cast the result back to int1 before returning.
-    element_type = int8
-    ptr_type = pointer_type(element_type, ptr_type.address_space)
-    ptr = semantic.cast(ptr, ptr_type)
+  value = _cast(value, pointee_type)
+  return tt_dialect.store(
+      ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
+  )
 
-  value = semantic.cast(value, element_type)
 
+def store(
+    ptr: tensor,
+    value: tensor,
+    mask: tensor | None = None,
+    *,
+    cache_modifier: str | None = None,
+    eviction_policy: str | None = None,
+) -> tensor:
   return tensor(
-      tt_dialect.store(
+      _store(
           ptr.handle,
           value.handle,
-          mask=mask.handle if mask is not None else None,
-          cache=cache_modifier,
-          evict=eviction_policy,
+          mask.handle if mask is not None else None,
+          cache_modifier=cache_modifier,
+          eviction_policy=eviction_policy,
       ),
       void,
   )
diff --git a/jaxlib/triton/triton_dialect_capi.cc b/jaxlib/triton/triton_dialect_capi.cc
index fdc090cf2..6a46d2914 100644
--- a/jaxlib/triton/triton_dialect_capi.cc
+++ b/jaxlib/triton/triton_dialect_capi.cc
@@ -34,7 +34,7 @@ MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace) {
       mlir::triton::PointerType::get(unwrap(pointeeType), addressSpace));
 }
 
-bool mlirTritonIsAPointerType(MlirType type) {
+bool mlirTritonIsAPointer(MlirType type) {
   return llvm::isa<mlir::triton::PointerType>(unwrap(type));
 }
 
@@ -43,6 +43,11 @@ MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType) {
       .getPointeeType());
 }
 
+int mlirTritonPointerTypeGetAddressSpace(MlirType pointerType) {
+  return llvm::cast<mlir::triton::PointerType>(unwrap(pointerType))
+      .getAddressSpace();
+}
+
 MlirAttribute mlirTritonInferReduceOpEncoding(MlirAttribute operandEncoding,
                                               int axis) {
   auto opEncoding = unwrap(operandEncoding);
diff --git a/jaxlib/triton/triton_dialect_capi.h b/jaxlib/triton/triton_dialect_capi.h
index 9b267ad71..8c27b5b82 100644
--- a/jaxlib/triton/triton_dialect_capi.h
+++ b/jaxlib/triton/triton_dialect_capi.h
@@ -27,9 +27,11 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Triton, triton);
 MLIR_CAPI_EXPORTED MlirType mlirTritonPointerTypeGet(MlirType pointeeType,
                                                      int addressSpace);
 
-MLIR_CAPI_EXPORTED bool mlirTritonIsAPointerType(MlirType type);
+MLIR_CAPI_EXPORTED bool mlirTritonIsAPointer(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType
 mlirTritonPointerTypeGetPointeeType(MlirType pointerType);
 
+MLIR_CAPI_EXPORTED int
+mlirTritonPointerTypeGetAddressSpace(MlirType pointerType);
+
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirTritonInferReduceOpEncoding(MlirAttribute operandEncoding, int axis);
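
For context, a rough usage sketch of the pieces this change touches, written with the same names `jaxlib/triton/compat.py` already uses above. It is not part of the patch: it assumes an active MLIR context and insertion point with the Triton dialect loaded, an `f32` element type, and `ptr`/`mask`/`other` `tensor` wrappers built elsewhere in `compat.py`; treat it as an illustration of intent rather than shipped code.

```python
# Sketch only; names outside this diff (f32, ptr, mask, other) are assumed
# to be in scope.

# PointerType now round-trips both of its parameters from Python:
ptr_ty = PointerType.get(f32, 1)   # pointee type + address space
assert ptr_ty.pointee_type == f32
assert ptr_ty.address_space == 1   # property added by this change

# The tensor-level API is unchanged for callers:
x = load(ptr, mask=mask, other=other)   # returns a `tensor`
store(ptr, x, mask=mask)

# The new private helpers work directly on ir.Value, handling the
# pointer/block validation and the int1 -> int8 pointee rewrite:
raw = _load(ptr.handle, mask.handle if mask is not None else None)
_store(ptr.handle, raw, mask.handle if mask is not None else None)
```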