Generate Python bindings for the Triton MLIR dialect

The bindings are not yet included in the jaxlib wheel. I will do that in a
follow up PR.

PiperOrigin-RevId: 595174466
This commit is contained in:
Sergei Lebedev 2024-01-02 11:54:20 -08:00 committed by jax authors
parent 15f4a8d2ec
commit e6c890171b
10 changed files with 302 additions and 0 deletions

View File

@ -40,6 +40,7 @@ py_library_providing_imports_info(
"//jaxlib",
"//jaxlib:cpu_feature_guard",
"//jaxlib:utils",
"//jaxlib/triton",
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",

View File

@ -20,6 +20,7 @@ from __future__ import annotations
import gc
import pathlib
import re
from typing import Any
try:
import jaxlib as jaxlib
@ -120,6 +121,12 @@ import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
import jaxlib.tpu_mosaic as tpu_mosaic # pytype: disable=import-error
triton_dialect: Any
try:
import jaxlib.triton.dialect as triton_dialect # pytype: disable=import-error
except ImportError:
triton_dialect = None
# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version

95
jaxlib/triton/BUILD Normal file
View File

@ -0,0 +1,95 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "pybind_extension", "pytype_strict_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
)
pytype_strict_library(
name = "triton",
srcs = [
"__init__.py",
"dialect.py",
":_triton_gen",
],
visibility = ["//visibility:public"],
deps = [
":_triton_ext",
"//jaxlib/mlir:core",
"//jaxlib/mlir:ir",
],
)
genrule(
name = "_triton_gen",
srcs = [":_triton_gen_raw"],
outs = ["_triton_gen.py"],
cmd = """
echo '# pytype: skip-file' > $@ && \
cat $(location :_triton_gen_raw) | sed -e 's/^from \\./from mlir\\.dialects\\./g' >> $@
""",
)
gentbl_filegroup(
name = "_triton_gen_raw",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=tt",
],
"_triton_gen_raw.py",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "triton.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@triton//:td_files",
],
)
pybind_extension(
name = "_triton_ext",
srcs = ["_triton_ext.cc"],
pytype_deps = [
"//jaxlib/mlir:ir",
],
pytype_srcs = ["_triton_ext.pyi"],
visibility = ["//visibility:private"],
deps = [
":triton_dialect_capi",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@pybind11",
],
)
cc_library(
name = "triton_dialect_capi",
srcs = ["triton_dialect_capi.cc"],
hdrs = ["triton_dialect_capi.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@triton//:TritonDialects",
],
)

15
jaxlib/triton/__init__.py Normal file
View File

@ -0,0 +1,15 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""High level APIs for working with the MLIR Triton dialect."""

View File

@ -0,0 +1,55 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "pybind11/detail/common.h"
#include "jaxlib/triton/triton_dialect_capi.h"
namespace py = pybind11;
PYBIND11_MODULE(_triton_ext, m) {
//
// Dialects.
//
m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle dialect = mlirGetDialectHandle__triton__();
mlirDialectHandleRegisterDialect(dialect, context);
if (load) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
//
// Types.
//
mlir::python::adaptors::mlir_type_subclass(m, "PointerType",
mlirTritonIsAPointerType)
.def_classmethod(
"get",
[](py::object cls, MlirType pointee_type, int64_t address_space) {
return cls(mlirTritonPointerTypeGet(pointee_type, address_space));
},
py::arg("cls"), py::arg("pointee_type"), py::arg("address_space"),
"Creates a PointerType type.")
.def_property_readonly("pointee_type", [](MlirType self) {
return mlirTritonPointerTypeGetPointeeType(self);
});
}

View File

@ -0,0 +1,29 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mlir import ir # type: ignore
def register_dialect(context: ir.Context, load: bool = ...): ...
class PointerType(ir.Type):
@classmethod
def get(cls, pointee_type: ir.Type, address_space: int) -> PointerType: ...
@staticmethod
def isinstance(other: ir.Type) -> bool: ...
@property
def pointee_type(self) -> ir.Type: ...

20
jaxlib/triton/dialect.py Normal file
View File

@ -0,0 +1,20 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Python bindings for the MLIR Triton dialect."""
from ._triton_ext import register_dialect, PointerType
from ._triton_gen import * # pylint: disable=wildcard-import

1
jaxlib/triton/triton.td Normal file
View File

@ -0,0 +1 @@
include "triton/Dialect/Triton/IR/TritonOps.td"

View File

@ -0,0 +1,44 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/triton/triton_dialect_capi.h"
#include "llvm/include/llvm/Support/Casting.h"
#include "mlir/include/mlir-c/IR.h"
#include "mlir/include/mlir/CAPI/IR.h"
#include "mlir/include/mlir/CAPI/Registration.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
extern "C" {
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Triton, triton,
mlir::triton::TritonDialect);
MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace) {
return wrap(
mlir::triton::PointerType::get(unwrap(pointeeType), addressSpace));
}
bool mlirTritonIsAPointerType(MlirType type) {
return llvm::isa<mlir::triton::PointerType>(unwrap(type));
}
MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType) {
return wrap(llvm::cast<mlir::triton::PointerType>(unwrap(pointerType))
.getPointeeType());
}
} // extern "C"

View File

@ -0,0 +1,35 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_
#define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_
#include "mlir/include/mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Triton, triton);
MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace);
bool mlirTritonIsAPointerType(MlirType type);
MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_