diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 57c576c4a..c634c52e9 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -15,6 +15,7 @@ load( "//jaxlib:jax.bzl", "py_extension", + "pybind_extension", "windows_cc_shared_mlir_library", ) load("//jaxlib:symlink_files.bzl", "symlink_inputs") @@ -117,6 +118,21 @@ py_library( ], ) +pybind_extension( + name = "_triton_ext", + srcs = ["triton_ext.cc"], + copts = COPTS, + linkopts = LINKOPTS, + pytype_srcs = ["_triton_ext.pyi"], + deps = [ + ":jaxlib_mlir_capi_shared_library", + "//jaxlib/triton:triton_dialect_capi_headers", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@pybind11", + ], +) + symlink_inputs( name = "_mlir_libs", rule = py_library, @@ -234,6 +250,7 @@ cc_library( name = "jaxlib_mlir_capi_objects", deps = [ "//jaxlib/mosaic:tpu_dialect_capi_objects", + "//jaxlib/triton:triton_dialect_capi_objects", "@llvm-project//mlir:CAPIArithObjects", "@llvm-project//mlir:CAPIIRObjects", "@llvm-project//mlir:CAPIMathObjects", diff --git a/jaxlib/triton/_triton_ext.pyi b/jaxlib/mlir/_mlir_libs/_triton_ext.pyi similarity index 100% rename from jaxlib/triton/_triton_ext.pyi rename to jaxlib/mlir/_mlir_libs/_triton_ext.pyi diff --git a/jaxlib/triton/_triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc similarity index 100% rename from jaxlib/triton/_triton_ext.cc rename to jaxlib/mlir/_mlir_libs/triton_ext.cc diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 33c75eb5f..24fb55f05 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -327,6 +327,8 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}", + "__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", ], @@ -339,8 +341,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi "__main__/jaxlib/triton/__init__.py", "__main__/jaxlib/triton/compat.py", "__main__/jaxlib/triton/dialect.py", - f"__main__/jaxlib/triton/_triton_ext.{pyext}", - "__main__/jaxlib/triton/_triton_ext.pyi", ], ) patch_copy_mlir_import( diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index b1ebb49a6..d7db462b0 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "py_deps", "pybind_extension", "pytype_strict_library") +load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library") load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") @@ -47,9 +47,9 @@ pytype_strict_library( ], visibility = ["//visibility:public"], deps = [ - ":_triton_ext", "//jaxlib/mlir:core", "//jaxlib/mlir:ir", + "//jaxlib/mlir/_mlir_libs:_triton_ext", ], ) @@ -100,22 +100,6 @@ gentbl_filegroup( ], ) -pybind_extension( - name = "_triton_ext", - srcs = ["_triton_ext.cc"], - pytype_deps = [ - "//jaxlib/mlir:ir", - ], - pytype_srcs = ["_triton_ext.pyi"], - visibility = ["//visibility:private"], - deps = [ - ":triton_dialect_capi", - "@llvm-project//mlir:CAPIIR", - "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", - "@pybind11", - ], -) - cc_library( name = "triton_dialect_capi", srcs = ["triton_dialect_capi.cc"], @@ -127,3 +111,26 @@ cc_library( "@triton//:TritonDialects", ], ) + +# Header-only target, used when using the C API from a separate shared library. +cc_library( + name = "triton_dialect_capi_headers", + hdrs = ["triton_dialect_capi.h"], + deps = [ + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + +# Alwayslink target, used when exporting the C API from a shared library. +cc_library( + name = "triton_dialect_capi_objects", + srcs = ["triton_dialect_capi.cc"], + hdrs = ["triton_dialect_capi.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:IR", + "@triton//:TritonDialects", + ], + alwayslink = True, +) diff --git a/jaxlib/triton/dialect.py b/jaxlib/triton/dialect.py index 96f867eef..315f35265 100644 --- a/jaxlib/triton/dialect.py +++ b/jaxlib/triton/dialect.py @@ -15,19 +15,19 @@ # ruff: noqa """Python bindings for the MLIR Triton dialect.""" + from __future__ import annotations from collections.abc import Sequence -from typing import Any -from jaxlib.mlir import ir - -from ._triton_enum_gen import * # pylint: disable=wildcard-import -from ._triton_ext import ( +from jaxlib.mlir._mlir_libs._triton_ext import ( PointerType, infer_reduce_op_encoding, register_dialect, ) +from jaxlib.mlir import ir + +from ._triton_enum_gen import * # pylint: disable=wildcard-import from ._triton_ops_gen import * # pylint: disable=wildcard-import