_triton_ext no longer links in MLIR C APIs

I re-used the same trick we do for the TPU dialect. Specifically, _triton_ext no longer depends on :triton_dialect_capi. Instead

* we include Triton dialect C bindings into :jaxlib_mlir_capi_objects
* and _triton_ext depends on :jaxlib_mlir_capi_objects and a header-only cc_library providing Triton dialect C bindings

This is a fork of #19680 with a few internal-only fixes.

PiperOrigin-RevId: 604929377
This commit is contained in:
Sergei Lebedev 2024-02-07 03:38:39 -08:00 committed by jax authors
parent 838e4e8dc6
commit 5e2e609a9b
6 changed files with 49 additions and 25 deletions

View File

@ -15,6 +15,7 @@
load(
"//jaxlib:jax.bzl",
"py_extension",
"pybind_extension",
"windows_cc_shared_mlir_library",
)
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
@ -117,6 +118,21 @@ py_library(
],
)
pybind_extension(
name = "_triton_ext",
srcs = ["triton_ext.cc"],
copts = COPTS,
linkopts = LINKOPTS,
pytype_srcs = ["_triton_ext.pyi"],
deps = [
":jaxlib_mlir_capi_shared_library",
"//jaxlib/triton:triton_dialect_capi_headers",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@pybind11",
],
)
symlink_inputs(
name = "_mlir_libs",
rule = py_library,
@ -234,6 +250,7 @@ cc_library(
name = "jaxlib_mlir_capi_objects",
deps = [
"//jaxlib/mosaic:tpu_dialect_capi_objects",
"//jaxlib/triton:triton_dialect_capi_objects",
"@llvm-project//mlir:CAPIArithObjects",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:CAPIMathObjects",

View File

@ -327,6 +327,8 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}",
],
@ -339,8 +341,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
"__main__/jaxlib/triton/__init__.py",
"__main__/jaxlib/triton/compat.py",
"__main__/jaxlib/triton/dialect.py",
f"__main__/jaxlib/triton/_triton_ext.{pyext}",
"__main__/jaxlib/triton/_triton_ext.pyi",
],
)
patch_copy_mlir_import(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "py_deps", "pybind_extension", "pytype_strict_library")
load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library")
load("@rules_python//python:defs.bzl", "py_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
@ -47,9 +47,9 @@ pytype_strict_library(
],
visibility = ["//visibility:public"],
deps = [
":_triton_ext",
"//jaxlib/mlir:core",
"//jaxlib/mlir:ir",
"//jaxlib/mlir/_mlir_libs:_triton_ext",
],
)
@ -100,22 +100,6 @@ gentbl_filegroup(
],
)
pybind_extension(
name = "_triton_ext",
srcs = ["_triton_ext.cc"],
pytype_deps = [
"//jaxlib/mlir:ir",
],
pytype_srcs = ["_triton_ext.pyi"],
visibility = ["//visibility:private"],
deps = [
":triton_dialect_capi",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@pybind11",
],
)
cc_library(
name = "triton_dialect_capi",
srcs = ["triton_dialect_capi.cc"],
@ -127,3 +111,26 @@ cc_library(
"@triton//:TritonDialects",
],
)
# Header-only target, used when using the C API from a separate shared library.
cc_library(
name = "triton_dialect_capi_headers",
hdrs = ["triton_dialect_capi.h"],
deps = [
"@llvm-project//mlir:CAPIIRHeaders",
],
)
# Alwayslink target, used when exporting the C API from a shared library.
cc_library(
name = "triton_dialect_capi_objects",
srcs = ["triton_dialect_capi.cc"],
hdrs = ["triton_dialect_capi.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@triton//:TritonDialects",
],
alwayslink = True,
)

View File

@ -15,19 +15,19 @@
# ruff: noqa
"""Python bindings for the MLIR Triton dialect."""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
from jaxlib.mlir import ir
from ._triton_enum_gen import * # pylint: disable=wildcard-import
from ._triton_ext import (
from jaxlib.mlir._mlir_libs._triton_ext import (
PointerType,
infer_reduce_op_encoding,
register_dialect,
)
from jaxlib.mlir import ir
from ._triton_enum_gen import * # pylint: disable=wildcard-import
from ._triton_ops_gen import * # pylint: disable=wildcard-import