mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Generate Python bindings for the Triton MLIR dialect
The bindings are not yet included in the jaxlib wheel. I will do that in a follow up PR. PiperOrigin-RevId: 595174466
This commit is contained in:
parent
15f4a8d2ec
commit
e6c890171b
@ -40,6 +40,7 @@ py_library_providing_imports_info(
|
||||
"//jaxlib",
|
||||
"//jaxlib:cpu_feature_guard",
|
||||
"//jaxlib:utils",
|
||||
"//jaxlib/triton",
|
||||
"//jaxlib/mlir/_mlir_libs:register_jax_dialects",
|
||||
"//jaxlib/mlir:arithmetic_dialect",
|
||||
"//jaxlib/mlir:builtin_dialect",
|
||||
|
@ -20,6 +20,7 @@ from __future__ import annotations
|
||||
import gc
|
||||
import pathlib
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import jaxlib as jaxlib
|
||||
@ -120,6 +121,12 @@ import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
|
||||
|
||||
import jaxlib.tpu_mosaic as tpu_mosaic # pytype: disable=import-error
|
||||
|
||||
triton_dialect: Any
|
||||
try:
|
||||
import jaxlib.triton.dialect as triton_dialect # pytype: disable=import-error
|
||||
except ImportError:
|
||||
triton_dialect = None
|
||||
|
||||
# Version number for MLIR:Python APIs, provided by jaxlib.
|
||||
mlir_api_version = xla_client.mlir_api_version
|
||||
|
||||
|
95
jaxlib/triton/BUILD
Normal file
95
jaxlib/triton/BUILD
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:jax.bzl", "pybind_extension", "pytype_strict_library")
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//:__subpackages__"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "triton",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"dialect.py",
|
||||
":_triton_gen",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":_triton_ext",
|
||||
"//jaxlib/mlir:core",
|
||||
"//jaxlib/mlir:ir",
|
||||
],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "_triton_gen",
|
||||
srcs = [":_triton_gen_raw"],
|
||||
outs = ["_triton_gen.py"],
|
||||
cmd = """
|
||||
echo '# pytype: skip-file' > $@ && \
|
||||
cat $(location :_triton_gen_raw) | sed -e 's/^from \\./from mlir\\.dialects\\./g' >> $@
|
||||
""",
|
||||
)
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "_triton_gen_raw",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-python-op-bindings",
|
||||
"-bind-dialect=tt",
|
||||
],
|
||||
"_triton_gen_raw.py",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "triton.td",
|
||||
deps = [
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@triton//:td_files",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_triton_ext",
|
||||
srcs = ["_triton_ext.cc"],
|
||||
pytype_deps = [
|
||||
"//jaxlib/mlir:ir",
|
||||
],
|
||||
pytype_srcs = ["_triton_ext.pyi"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":triton_dialect_capi",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "triton_dialect_capi",
|
||||
srcs = ["triton_dialect_capi.cc"],
|
||||
hdrs = ["triton_dialect_capi.h"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@triton//:TritonDialects",
|
||||
],
|
||||
)
|
15
jaxlib/triton/__init__.py
Normal file
15
jaxlib/triton/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""High level APIs for working with the MLIR Triton dialect."""
|
55
jaxlib/triton/_triton_ext.cc
Normal file
55
jaxlib/triton/_triton_ext.cc
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "pybind11/detail/common.h"
|
||||
#include "jaxlib/triton/triton_dialect_capi.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_triton_ext, m) {
|
||||
//
|
||||
// Dialects.
|
||||
//
|
||||
|
||||
m.def(
|
||||
"register_dialect",
|
||||
[](MlirContext context, bool load) {
|
||||
MlirDialectHandle dialect = mlirGetDialectHandle__triton__();
|
||||
mlirDialectHandleRegisterDialect(dialect, context);
|
||||
if (load) {
|
||||
mlirDialectHandleLoadDialect(dialect, context);
|
||||
}
|
||||
},
|
||||
py::arg("context"), py::arg("load") = true);
|
||||
|
||||
//
|
||||
// Types.
|
||||
//
|
||||
|
||||
mlir::python::adaptors::mlir_type_subclass(m, "PointerType",
|
||||
mlirTritonIsAPointerType)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](py::object cls, MlirType pointee_type, int64_t address_space) {
|
||||
return cls(mlirTritonPointerTypeGet(pointee_type, address_space));
|
||||
},
|
||||
py::arg("cls"), py::arg("pointee_type"), py::arg("address_space"),
|
||||
"Creates a PointerType type.")
|
||||
.def_property_readonly("pointee_type", [](MlirType self) {
|
||||
return mlirTritonPointerTypeGetPointeeType(self);
|
||||
});
|
||||
}
|
29
jaxlib/triton/_triton_ext.pyi
Normal file
29
jaxlib/triton/_triton_ext.pyi
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mlir import ir # type: ignore
|
||||
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = ...): ...
|
||||
|
||||
|
||||
class PointerType(ir.Type):
|
||||
@classmethod
|
||||
def get(cls, pointee_type: ir.Type, address_space: int) -> PointerType: ...
|
||||
|
||||
@staticmethod
|
||||
def isinstance(other: ir.Type) -> bool: ...
|
||||
|
||||
@property
|
||||
def pointee_type(self) -> ir.Type: ...
|
20
jaxlib/triton/dialect.py
Normal file
20
jaxlib/triton/dialect.py
Normal file
@ -0,0 +1,20 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
"""Python bindings for the MLIR Triton dialect."""
|
||||
|
||||
from ._triton_ext import register_dialect, PointerType
|
||||
from ._triton_gen import * # pylint: disable=wildcard-import
|
1
jaxlib/triton/triton.td
Normal file
1
jaxlib/triton/triton.td
Normal file
@ -0,0 +1 @@
|
||||
include "triton/Dialect/Triton/IR/TritonOps.td"
|
44
jaxlib/triton/triton_dialect_capi.cc
Normal file
44
jaxlib/triton/triton_dialect_capi.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/triton/triton_dialect_capi.h"
|
||||
|
||||
#include "llvm/include/llvm/Support/Casting.h"
|
||||
#include "mlir/include/mlir-c/IR.h"
|
||||
#include "mlir/include/mlir/CAPI/IR.h"
|
||||
#include "mlir/include/mlir/CAPI/Registration.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Triton, triton,
|
||||
mlir::triton::TritonDialect);
|
||||
|
||||
MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace) {
|
||||
return wrap(
|
||||
mlir::triton::PointerType::get(unwrap(pointeeType), addressSpace));
|
||||
}
|
||||
|
||||
bool mlirTritonIsAPointerType(MlirType type) {
|
||||
return llvm::isa<mlir::triton::PointerType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType) {
|
||||
return wrap(llvm::cast<mlir::triton::PointerType>(unwrap(pointerType))
|
||||
.getPointeeType());
|
||||
}
|
||||
|
||||
} // extern "C"
|
35
jaxlib/triton/triton_dialect_capi.h
Normal file
35
jaxlib/triton/triton_dialect_capi.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_
|
||||
#define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_
|
||||
|
||||
#include "mlir/include/mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Triton, triton);
|
||||
|
||||
MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace);
|
||||
bool mlirTritonIsAPointerType(MlirType type);
|
||||
MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
#endif // JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_
|
Loading…
x
Reference in New Issue
Block a user