diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 485fd5851..d0028a5f0 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -40,6 +40,7 @@ py_library_providing_imports_info( "//jaxlib", "//jaxlib:cpu_feature_guard", "//jaxlib:utils", + "//jaxlib/triton", "//jaxlib/mlir/_mlir_libs:register_jax_dialects", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 07d6d568e..a0d7a4e99 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -20,6 +20,7 @@ from __future__ import annotations import gc import pathlib import re +from typing import Any try: import jaxlib as jaxlib @@ -120,6 +121,12 @@ import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error import jaxlib.tpu_mosaic as tpu_mosaic # pytype: disable=import-error +triton_dialect: Any +try: + import jaxlib.triton.dialect as triton_dialect # pytype: disable=import-error +except ImportError: + triton_dialect = None + # Version number for MLIR:Python APIs, provided by jaxlib. mlir_api_version = xla_client.mlir_api_version diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD new file mode 100644 index 000000000..07c9634a6 --- /dev/null +++ b/jaxlib/triton/BUILD @@ -0,0 +1,95 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//jaxlib:jax.bzl", "pybind_extension", "pytype_strict_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") + +licenses(["notice"]) + +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) + +pytype_strict_library( + name = "triton", + srcs = [ + "__init__.py", + "dialect.py", + ":_triton_gen", + ], + visibility = ["//visibility:public"], + deps = [ + ":_triton_ext", + "//jaxlib/mlir:core", + "//jaxlib/mlir:ir", + ], +) + +genrule( + name = "_triton_gen", + srcs = [":_triton_gen_raw"], + outs = ["_triton_gen.py"], + cmd = """ + echo '# pytype: skip-file' > $@ && \ + cat $(location :_triton_gen_raw) | sed -e 's/^from \\./from mlir\\.dialects\\./g' >> $@ + """, +) + +gentbl_filegroup( + name = "_triton_gen_raw", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=tt", + ], + "_triton_gen_raw.py", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "triton.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@triton//:td_files", + ], +) + +pybind_extension( + name = "_triton_ext", + srcs = ["_triton_ext.cc"], + pytype_deps = [ + "//jaxlib/mlir:ir", + ], + pytype_srcs = ["_triton_ext.pyi"], + visibility = ["//visibility:private"], + deps = [ + ":triton_dialect_capi", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@pybind11", + ], +) + +cc_library( + name = "triton_dialect_capi", + srcs = ["triton_dialect_capi.cc"], + hdrs = ["triton_dialect_capi.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:IR", + "@triton//:TritonDialects", + ], +) diff --git a/jaxlib/triton/__init__.py b/jaxlib/triton/__init__.py new file mode 100644 index 000000000..6e67290de --- /dev/null +++ b/jaxlib/triton/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""High level APIs for working with the MLIR Triton dialect.""" diff --git a/jaxlib/triton/_triton_ext.cc b/jaxlib/triton/_triton_ext.cc new file mode 100644 index 000000000..2476ae9ae --- /dev/null +++ b/jaxlib/triton/_triton_ext.cc @@ -0,0 +1,55 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "pybind11/detail/common.h" +#include "jaxlib/triton/triton_dialect_capi.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_triton_ext, m) { + // + // Dialects. + // + + m.def( + "register_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle dialect = mlirGetDialectHandle__triton__(); + mlirDialectHandleRegisterDialect(dialect, context); + if (load) { + mlirDialectHandleLoadDialect(dialect, context); + } + }, + py::arg("context"), py::arg("load") = true); + + // + // Types. + // + + mlir::python::adaptors::mlir_type_subclass(m, "PointerType", + mlirTritonIsAPointerType) + .def_classmethod( + "get", + [](py::object cls, MlirType pointee_type, int64_t address_space) { + return cls(mlirTritonPointerTypeGet(pointee_type, address_space)); + }, + py::arg("cls"), py::arg("pointee_type"), py::arg("address_space"), + "Creates a PointerType type.") + .def_property_readonly("pointee_type", [](MlirType self) { + return mlirTritonPointerTypeGetPointeeType(self); + }); +} diff --git a/jaxlib/triton/_triton_ext.pyi b/jaxlib/triton/_triton_ext.pyi new file mode 100644 index 000000000..928455e99 --- /dev/null +++ b/jaxlib/triton/_triton_ext.pyi @@ -0,0 +1,29 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mlir import ir # type: ignore + + +def register_dialect(context: ir.Context, load: bool = ...): ... + + +class PointerType(ir.Type): + @classmethod + def get(cls, pointee_type: ir.Type, address_space: int) -> PointerType: ... + + @staticmethod + def isinstance(other: ir.Type) -> bool: ... + + @property + def pointee_type(self) -> ir.Type: ... diff --git a/jaxlib/triton/dialect.py b/jaxlib/triton/dialect.py new file mode 100644 index 000000000..3add96edf --- /dev/null +++ b/jaxlib/triton/dialect.py @@ -0,0 +1,20 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +"""Python bindings for the MLIR Triton dialect.""" + +from ._triton_ext import register_dialect, PointerType +from ._triton_gen import * # pylint: disable=wildcard-import diff --git a/jaxlib/triton/triton.td b/jaxlib/triton/triton.td new file mode 100644 index 000000000..ac8c7d43e --- /dev/null +++ b/jaxlib/triton/triton.td @@ -0,0 +1 @@ +include "triton/Dialect/Triton/IR/TritonOps.td" diff --git a/jaxlib/triton/triton_dialect_capi.cc b/jaxlib/triton/triton_dialect_capi.cc new file mode 100644 index 000000000..c60ebf476 --- /dev/null +++ b/jaxlib/triton/triton_dialect_capi.cc @@ -0,0 +1,44 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/triton/triton_dialect_capi.h" + +#include "llvm/include/llvm/Support/Casting.h" +#include "mlir/include/mlir-c/IR.h" +#include "mlir/include/mlir/CAPI/IR.h" +#include "mlir/include/mlir/CAPI/Registration.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +extern "C" { + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Triton, triton, + mlir::triton::TritonDialect); + +MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace) { + return wrap( + mlir::triton::PointerType::get(unwrap(pointeeType), addressSpace)); +} + +bool mlirTritonIsAPointerType(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType) { + return wrap(llvm::cast(unwrap(pointerType)) + .getPointeeType()); +} + +} // extern "C" diff --git a/jaxlib/triton/triton_dialect_capi.h b/jaxlib/triton/triton_dialect_capi.h new file mode 100644 index 000000000..4add25755 --- /dev/null +++ b/jaxlib/triton/triton_dialect_capi.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ +#define JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_ + +#include "mlir/include/mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Triton, triton); + +MlirType mlirTritonPointerTypeGet(MlirType pointeeType, int addressSpace); +bool mlirTritonIsAPointerType(MlirType type); +MlirType mlirTritonPointerTypeGetPointeeType(MlirType pointerType); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // JAXLIB_TRITON_TRITON_DIALECT_CAPI_H_