Sergei Lebedev 5e2e609a9b _triton_ext no longer links in MLIR C APIs
I re-used the same trick we use for the TPU dialect. Specifically, _triton_ext no longer depends on :triton_dialect_capi. Instead:

* we include Triton dialect C bindings into :jaxlib_mlir_capi_objects
* and _triton_ext depends on :jaxlib_mlir_capi_objects and a header-only cc_library providing Triton dialect C bindings

This is a fork of #19680 with a few internal-only fixes.

PiperOrigin-RevId: 604929377
2024-02-07 03:39:29 -08:00

137 lines
3.4 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library")
load("@rules_python//python:defs.bzl", "py_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
# Package-wide license declaration.
licenses(["notice"])

# Package defaults: targets that do not set their own `visibility` are visible
# to the root package of this repository and everything beneath it.
package(
    default_applicable_licenses = [],
    default_visibility = ["//:__subpackages__"],
)
# Python library wrapping `compat.py`. It builds on the Triton dialect bindings
# in this package plus a handful of core MLIR dialect bindings and NumPy.
py_library(
    name = "compat",
    srcs = ["compat.py"],
    deps = [
        ":triton",
        "//jaxlib/mlir:arithmetic_dialect",
        "//jaxlib/mlir:ir",
        "//jaxlib/mlir:math_dialect",
        "//jaxlib/mlir:scf_dialect",
    ] + py_deps(
        "numpy",
        # Triton itself is intentionally not listed here: users must add a
        # Triton dependency manually.
    ),
)
# Python bindings for the Triton MLIR dialect: the hand-written sources plus
# the post-processed generated modules produced by the `_triton_gen` genrule.
pytype_strict_library(
    name = "triton",
    srcs = [
        "__init__.py",
        "dialect.py",
        # Outputs of the genrule below (_triton_ops_gen.py, _triton_enum_gen.py).
        ":_triton_gen",
    ],
    # Overrides the package default: this target is visible everywhere.
    visibility = ["//visibility:public"],
    deps = [
        "//jaxlib/mlir:core",
        "//jaxlib/mlir:ir",
        "//jaxlib/mlir/_mlir_libs:_triton_ext",
    ],
)
# Post-processes the raw tblgen-generated Python bindings (the `*_raw.py`
# outputs of `_triton_gen_raw`) into the modules shipped with the `triton`
# library. For every input file the rule:
#   1. writes a `# pytype: skip-file` header, and
#   2. appends the input with its relative imports rewritten:
#        `from ..`  ->  `from mlir.`
#        `from .`   ->  `from mlir.dialects.`
#
# Notes on the shell command:
#   * `$$` is Bazel's escape for a literal `$` handed to the shell.
#   * The output path is built from `$(RULEDIR)` and the input's basename, so
#     only the file name (never a directory component) has `_raw` stripped and
#     the result always lands where `outs` declares it.
#   * The header write and the sed run are separate statements so that a
#     failure of either one aborts the action.
#   * Both substitutions are anchored with `^`, so at most one match per line
#     is possible; the second expression cannot re-match a line already
#     rewritten by the first because that line now starts with `from mlir.`.
genrule(
    name = "_triton_gen",
    srcs = [
        "_triton_ops_gen_raw.py",
        "_triton_enum_gen_raw.py",
    ],
    outs = [
        "_triton_ops_gen.py",
        "_triton_enum_gen.py",
    ],
    cmd = """
    for src in $(SRCS); do
      base=$$(basename $${src})
      out=$(RULEDIR)/$${base/_raw/}
      echo '# pytype: skip-file' > $${out}
      sed -e 's/^from \\.\\./from mlir\\./' \\
          -e 's/^from \\./from mlir\\.dialects\\./' \\
          $${src} >> $${out}
    done
    """,
)
# Runs mlir-tblgen over `triton.td` to emit the raw Python bindings for the
# `tt` dialect. Each `tbl_outs` entry pairs the tblgen flags with the file
# those flags generate. The resulting `*_raw.py` files are the inputs of the
# `_triton_gen` genrule.
gentbl_filegroup(
    name = "_triton_gen_raw",
    tbl_outs = [
        # Enum bindings.
        (
            [
                "-gen-python-enum-bindings",
                "-bind-dialect=tt",
            ],
            "_triton_enum_gen_raw.py",
        ),
        # Op bindings.
        (
            [
                "-gen-python-op-bindings",
                "-bind-dialect=tt",
            ],
            "_triton_ops_gen_raw.py",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "triton.td",
    deps = [
        "@llvm-project//mlir:OpBaseTdFiles",
        "@triton//:td_files",
    ],
)
# Regular (non-alwayslink) build of the Triton dialect C API: implementation
# and public header together.
cc_library(
    name = "triton_dialect_capi",
    srcs = ["triton_dialect_capi.cc"],
    hdrs = ["triton_dialect_capi.h"],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIR",
        "@llvm-project//mlir:IR",
        "@triton//:TritonDialects",
    ],
)
# Header-only target, used when using the C API from a separate shared library.
# It exposes only `triton_dialect_capi.h` (no `srcs`), so depending on it does
# not compile the C API implementation into the dependent target.
cc_library(
    name = "triton_dialect_capi_headers",
    hdrs = ["triton_dialect_capi.h"],
    deps = [
        "@llvm-project//mlir:CAPIIRHeaders",
    ],
)
# Alwayslink target, used when exporting the C API from a shared library.
# Same sources, headers and deps as `:triton_dialect_capi`; the only
# difference is `alwayslink = True`, which makes the linker keep every object
# file from `srcs` even when no symbol in it is referenced by the binary.
cc_library(
    name = "triton_dialect_capi_objects",
    srcs = ["triton_dialect_capi.cc"],
    hdrs = ["triton_dialect_capi.h"],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIR",
        "@llvm-project//mlir:IR",
        "@triton//:TritonDialects",
    ],
    alwayslink = True,
)