mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
_triton_ext no longer links in MLIR C APIs
I re-used the same trick we do for the TPU dialect. Specifically, _triton_ext no longer depends on :triton_dialect_capi. Instead * we include Triton dialect C bindings into :jaxlib_mlir_capi_objects * and _triton_ext depends on :jaxlib_mlir_capi_objects and a header-only cc_library providing Triton dialect C bindings This is a fork of #19680 with a few internal-only fixes. PiperOrigin-RevId: 604929377
This commit is contained in:
parent
838e4e8dc6
commit
5e2e609a9b
@ -15,6 +15,7 @@
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_extension",
|
||||
"pybind_extension",
|
||||
"windows_cc_shared_mlir_library",
|
||||
)
|
||||
load("//jaxlib:symlink_files.bzl", "symlink_inputs")
|
||||
@ -117,6 +118,21 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_triton_ext",
|
||||
srcs = ["triton_ext.cc"],
|
||||
copts = COPTS,
|
||||
linkopts = LINKOPTS,
|
||||
pytype_srcs = ["_triton_ext.pyi"],
|
||||
deps = [
|
||||
":jaxlib_mlir_capi_shared_library",
|
||||
"//jaxlib/triton:triton_dialect_capi_headers",
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
symlink_inputs(
|
||||
name = "_mlir_libs",
|
||||
rule = py_library,
|
||||
@ -234,6 +250,7 @@ cc_library(
|
||||
name = "jaxlib_mlir_capi_objects",
|
||||
deps = [
|
||||
"//jaxlib/mosaic:tpu_dialect_capi_objects",
|
||||
"//jaxlib/triton:triton_dialect_capi_objects",
|
||||
"@llvm-project//mlir:CAPIArithObjects",
|
||||
"@llvm-project//mlir:CAPIIRObjects",
|
||||
"@llvm-project//mlir:CAPIMathObjects",
|
||||
|
@ -327,6 +327,8 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.{pyext}",
|
||||
"__main__/jaxlib/mlir/_mlir_libs/_triton_ext.pyi",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}",
|
||||
f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}",
|
||||
],
|
||||
@ -339,8 +341,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
|
||||
"__main__/jaxlib/triton/__init__.py",
|
||||
"__main__/jaxlib/triton/compat.py",
|
||||
"__main__/jaxlib/triton/dialect.py",
|
||||
f"__main__/jaxlib/triton/_triton_ext.{pyext}",
|
||||
"__main__/jaxlib/triton/_triton_ext.pyi",
|
||||
],
|
||||
)
|
||||
patch_copy_mlir_import(
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("//jaxlib:jax.bzl", "py_deps", "pybind_extension", "pytype_strict_library")
|
||||
load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
|
||||
|
||||
@ -47,9 +47,9 @@ pytype_strict_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":_triton_ext",
|
||||
"//jaxlib/mlir:core",
|
||||
"//jaxlib/mlir:ir",
|
||||
"//jaxlib/mlir/_mlir_libs:_triton_ext",
|
||||
],
|
||||
)
|
||||
|
||||
@ -100,22 +100,6 @@ gentbl_filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_triton_ext",
|
||||
srcs = ["_triton_ext.cc"],
|
||||
pytype_deps = [
|
||||
"//jaxlib/mlir:ir",
|
||||
],
|
||||
pytype_srcs = ["_triton_ext.pyi"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":triton_dialect_capi",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "triton_dialect_capi",
|
||||
srcs = ["triton_dialect_capi.cc"],
|
||||
@ -127,3 +111,26 @@ cc_library(
|
||||
"@triton//:TritonDialects",
|
||||
],
|
||||
)
|
||||
|
||||
# Header-only target, used when using the C API from a separate shared library.
|
||||
cc_library(
|
||||
name = "triton_dialect_capi_headers",
|
||||
hdrs = ["triton_dialect_capi.h"],
|
||||
deps = [
|
||||
"@llvm-project//mlir:CAPIIRHeaders",
|
||||
],
|
||||
)
|
||||
|
||||
# Alwayslink target, used when exporting the C API from a shared library.
|
||||
cc_library(
|
||||
name = "triton_dialect_capi_objects",
|
||||
srcs = ["triton_dialect_capi.cc"],
|
||||
hdrs = ["triton_dialect_capi.h"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@triton//:TritonDialects",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
@ -15,19 +15,19 @@
|
||||
# ruff: noqa
|
||||
|
||||
"""Python bindings for the MLIR Triton dialect."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
from ._triton_enum_gen import * # pylint: disable=wildcard-import
|
||||
from ._triton_ext import (
|
||||
from jaxlib.mlir._mlir_libs._triton_ext import (
|
||||
PointerType,
|
||||
infer_reduce_op_encoding,
|
||||
register_dialect,
|
||||
)
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
from ._triton_enum_gen import * # pylint: disable=wildcard-import
|
||||
from ._triton_ops_gen import * # pylint: disable=wildcard-import
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user